ngclearn.utils.analysis package

Submodules

ngclearn.utils.analysis.attentive_probe module

class ngclearn.utils.analysis.attentive_probe.AttentiveProbe(dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64, target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_LN_input=False, use_softmax=True, dropout=0.5, eta=0.0002, eta_decay=0.0, min_eta=1e-05, **kwargs)[source]

Bases: Probe

This implements a nonlinear attentive probe, which is useful for evaluating the quality of encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot encodings or real-valued vector regression targets).

Parameters:
  • dkey – init seed key

  • source_seq_length – length of input sequence (e.g., height x width of the image feature)

  • input_dim – input dimensionality of probe

  • out_dim – output dimensionality of probe

  • num_heads – number of cross-attention heads

  • attn_dim – output dimensionality of each cross-attention head

  • target_seq_length – length of the target sequence; set to one to pool (i.e., the source sequence is mapped to a target sequence of length 1)

  • learnable_query_dim – target sequence dim (output dimension of cross-attention portion of probe)

  • batch_size – size of batches to process per internal call to update (or process)

  • hid_dim – dimensionality of hidden layer(s) of MLP portion of probe

  • use_LN – should layer normalization be used within MLP portions of probe or not?

  • use_softmax – should a softmax be applied to output of probe or not?
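
As an illustration of why target_seq_length is set to one, the sketch below shows a single query attending over all source_seq_length positions and returning one pooled vector per example. It is written in NumPy for portability and is not the ngclearn implementation; the names attentive_pool, query, W_k, and W_v are hypothetical.

```python
import numpy as np

def attentive_pool(x, query, W_k, W_v):
    """Pool a batch of sequences (B, S, D) into (B, 1, A) with one query.

    query: (1, A) query vector (a target sequence of length 1)
    W_k, W_v: (D, A) hypothetical key/value projection matrices
    """
    keys = x @ W_k                                        # (B, S, A)
    values = x @ W_v                                      # (B, S, A)
    scores = keys @ query.T / np.sqrt(query.shape[-1])    # (B, S, 1)
    scores = np.swapaxes(scores, 1, 2)                    # (B, 1, S)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ values, weights                      # (B, 1, A), (B, 1, S)

rng = np.random.default_rng(0)
B, S, D, A = 4, 49, 16, 8          # e.g., S = 7 x 7 image feature map
x = rng.normal(size=(B, S, D))
query = rng.normal(size=(1, A))
W_k = rng.normal(size=(D, A)) * 0.1
W_v = rng.normal(size=(D, A)) * 0.1
pooled, weights = attentive_pool(x, query, W_k, W_v)
print(pooled.shape, weights.shape)
```

The attention weights form a distribution over the S source positions, so the pooled vector is a learned weighted average rather than a plain mean.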

process(embeddings, dkey=None)[source]

Runs the probe’s inference scheme given an input batch of sequences of encodings/embeddings.

Parameters:
  • embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)

  • dkey – Optional JAX noise key

Returns:

probe output scores/probability values

update(embeddings, labels, dkey=None)[source]

Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally assigned labels/target vector values.

Parameters:
  • embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)

  • labels – target values that map to embedding sequence; shape (B, target_value_dim)

  • dkey – Optional JAX noise key

Returns:

probe output scores/probability values

ngclearn.utils.analysis.attentive_probe.cross_attention(dkey, params: tuple, x1: Array, x2: Array, mask: Array, n_heads: int = 8, dropout_rate: float = 0.0) → Array[source]

Runs cross-attention given a tuple of parameters and two sequences (x1 and x2). The function takes a query sequence x1 and a key-value sequence x2, and returns an output of the same shape as x1. B is the batch size, T is the length of the query sequence, and S is the length of the key-value sequence. Dq is the feature dimension of the query sequence, and Dkv is the feature dimension of the key-value sequence. H is the number of attention heads.

Parameters:
  • dkey – JAX key to trigger any internal noise (drop-out)

  • params (tuple) – tuple of parameters

  • x1 (jax.Array) – query sequence. Shape: (B, T, Dq)

  • x2 (jax.Array) – key-value sequence. Shape: (B, S, Dkv)

  • mask (jax.Array) – mask tensor. Shape: (B, T, S)

  • n_heads (int, optional) – number of attention heads. Defaults to 8.

  • dropout_rate (float, optional) – dropout rate. Defaults to 0.0.

Returns:

output of cross-attention

Return type:

jax.Array
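
The shape conventions above can be traced through a minimal multi-head cross-attention sketch. It uses NumPy and omits dropout. The parameter layout (Wq, Wk, Wv, Wo), the function name, and the mask convention (True marks a blocked position) are assumptions made for illustration, not the layout used by ngclearn.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention_sketch(params, x1, x2, mask=None, n_heads=2):
    """x1: (B, T, Dq) queries; x2: (B, S, Dkv) keys/values; mask: (B, T, S)."""
    Wq, Wk, Wv, Wo = params            # hypothetical parameter layout
    B, T, _ = x1.shape
    S = x2.shape[1]
    D = Wq.shape[1]                    # total attention width, split across heads
    E = D // n_heads                   # per-head width

    def split_heads(z, length):
        return z.reshape(B, length, n_heads, E).transpose(0, 2, 1, 3)

    q = split_heads(x1 @ Wq, T)        # (B, H, T, E)
    k = split_heads(x2 @ Wk, S)        # (B, H, S, E)
    v = split_heads(x2 @ Wv, S)        # (B, H, S, E)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(E)          # (B, H, T, S)
    if mask is not None:
        # Assumed convention: True marks positions that must not be attended to.
        scores = np.where(mask[:, None, :, :], -1e9, scores)
    attn = softmax(scores, axis=-1)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, D)    # merge heads
    return out @ Wo, attn              # (B, T, Dq), (B, H, T, S)

rng = np.random.default_rng(0)
B, T, S, Dq, Dkv, D, H = 2, 3, 5, 6, 4, 8, 2
params = (rng.normal(size=(Dq, D)), rng.normal(size=(Dkv, D)),
          rng.normal(size=(Dkv, D)), rng.normal(size=(D, Dq)))
x1 = rng.normal(size=(B, T, Dq))
x2 = rng.normal(size=(B, S, Dkv))
mask = np.zeros((B, T, S), dtype=bool)
mask[:, :, -1] = True                  # block the last key-value position
out, attn = cross_attention_sketch(params, x1, x2, mask, n_heads=H)
print(out.shape, attn.shape)
```

The output has the same shape as x1 because the final projection maps the merged heads back to Dq.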

ngclearn.utils.analysis.attentive_probe.eval_attention_probe(dkey, params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True)[source]

Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned labels/regression targets.

Parameters:
  • dkey – JAX key to trigger any internal noise (as in drop-out)

  • params – parameters tuple/list of probe

  • encodings – input encoding vectors/data

  • labels – output target values (e.g., labels, regression target vectors)

  • mask – optional mask to be applied to internal cross-attention

  • n_heads – number of attention heads

  • dropout – if >0, triggers drop-out applied internally to cross-attention

  • use_LN – use layer normalization?

  • use_LN_input – use layer normalization on input encodings?

  • use_softmax – should softmax be applied to output of attention probe? (useful for classification)

Returns:

current loss value, output scores/probabilities

ngclearn.utils.analysis.attentive_probe.masked_fill(x: Array, mask: Array, value=0) → Array[source]

Returns an array in which the entries of x selected by mask are replaced with value, while all other entries keep their original values.

Parameters:
  • x (jax.Array) – input array

  • mask (jax.Array) – mask array selecting the entries to fill

  • value (int, optional) – fill value. Defaults to 0.

Returns:

array with the fill applied

Return type:

jax.Array
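
A masked fill of this kind is commonly a single where-selection. The NumPy sketch below assumes that True entries of the mask are the ones replaced; the function name masked_fill_sketch is hypothetical, and jax.numpy.where accepts the same arguments.

```python
import numpy as np

def masked_fill_sketch(x, mask, value=0):
    """Replace entries of x where mask is True with value; keep the rest."""
    return np.where(mask, value, x)

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
mask = np.array([[False, True, False],
                 [True, False, False]])
filled = masked_fill_sketch(x, mask, value=-1.0)
print(filled)
```

In attention code, the same pattern is typically applied to the score matrix with a large negative fill value so that blocked positions receive near-zero weight after the softmax.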

ngclearn.utils.analysis.attentive_probe.run_attention_probe(dkey, params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True)[source]

Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).

Parameters:
  • dkey – JAX key for any internal noise to be applied

  • params – parameters tuple/list of probe

  • encodings – input encoding vectors/data

  • mask – optional mask to be applied to internal cross-attention

  • n_heads – number of attention heads

  • dropout – if >0, triggers drop-out applied internally to cross-attention

  • use_LN – use layer normalization?

  • use_LN_input – use layer normalization on input encodings?

  • use_softmax – should softmax be applied to output of attention probe? (useful for classification)

Returns:

output scores/probabilities, cross-attention (hidden) features

ngclearn.utils.analysis.linear_probe module

class ngclearn.utils.analysis.linear_probe.LinearProbe(dkey, source_seq_length, input_dim, out_dim, batch_size=1, use_LN=False, use_softmax=False, **kwargs)[source]

Bases: Probe

This implements a regularized linear probe, which is useful for evaluating the quality of encodings/embeddings in light of some supervisory downstream data (e.g., label one-hot encodings or real-valued vector regression targets). Note that this probe allows for configurable Elastic-net (L1+L2) regularization.

Parameters:
  • dkey – init seed key

  • source_seq_length – length of input sequence (e.g., height x width of the image feature)

  • input_dim – input dimensionality of probe

  • out_dim – output dimensionality of probe

  • batch_size – size of batches to process per internal call to update (or process)

  • use_LN – should layer normalization be used on incoming input vectors given to this probe?

  • use_softmax – should a softmax be applied to output of probe or not?

process(embeddings, dkey=None)[source]

Runs the probe’s inference scheme given an input batch of sequences of encodings/embeddings.

Parameters:
  • embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)

  • dkey – Optional JAX noise key

Returns:

probe output scores/probability values

update(embeddings, labels, dkey=None)[source]

Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally assigned labels/target vector values.

Parameters:
  • embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)

  • labels – target values that map to embedding sequence; shape (B, target_value_dim)

  • dkey – Optional JAX noise key

Returns:

probe output scores/probability values

ngclearn.utils.analysis.linear_probe.eval_linear_probe(params, x, y, use_softmax=True, use_LN=False)[source]

ngclearn.utils.analysis.linear_probe.run_linear_probe(params, x, use_softmax=False, use_LN=False)[source]
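
These two functions carry no description, so the following is only a sketch of what a regularized linear probe of this kind might compute, written in NumPy. Everything in it is an assumption: the (W, b) parameter layout, the flattening of the sequence axis, the choice of loss, and the l1 and l2 penalty coefficients are illustrative and are not taken from ngclearn.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def run_linear_probe_sketch(params, x, use_softmax=False, use_LN=False):
    W, b = params                        # hypothetical parameter layout
    x = x.reshape(x.shape[0], -1)        # flatten (B, T, D) to (B, T*D)
    if use_LN:
        x = layer_norm(x)
    y = x @ W + b
    return softmax(y) if use_softmax else y

def eval_linear_probe_sketch(params, x, y, use_softmax=True, use_LN=False,
                             l1=0.0, l2=0.0):
    W, _ = params
    y_hat = run_linear_probe_sketch(params, x, use_softmax, use_LN)
    if use_softmax:                      # cross-entropy for class probabilities
        loss = -np.mean(np.sum(y * np.log(y_hat + 1e-8), axis=-1))
    else:                                # squared error for regression targets
        loss = np.mean(np.sum((y_hat - y) ** 2, axis=-1))
    # Elastic-net penalty: L1 plus L2 on the weights (coefficients are assumptions)
    loss = loss + l1 * np.abs(W).sum() + l2 * np.square(W).sum()
    return loss, y_hat

rng = np.random.default_rng(0)
B, T, D, C = 8, 4, 5, 3
params = (rng.normal(size=(T * D, C)) * 0.1, np.zeros((1, C)))
x = rng.normal(size=(B, T, D))
y = np.eye(C)[rng.integers(0, C, size=B)]    # one-hot labels
loss, probs = eval_linear_probe_sketch(params, x, y, l1=1e-3, l2=1e-3)
print(float(loss), probs.shape)
```

The split mirrors the two signatures: one function produces outputs only, and the other also needs targets y in order to return a loss.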

ngclearn.utils.analysis.probe module

class ngclearn.utils.analysis.probe.Probe(dkey, batch_size=1, dev_batch_size=1, **kwargs)[source]

Bases: object

General framework for an analysis probe (that may or may not be learnable in an iterative fashion).

Parameters:
  • dkey – init seed key

  • batch_size – size of batches to process per internal call to update (or process)

fit(dataset, dev_dataset=None, n_iter=50, patience=20)[source]

Fits this probe to a pool of data.

Parameters:
  • dataset – a dataset tuple containing two design tensors/matrices (X, Y), with the first containing encoding vector sequences of shape (N, T, embed_dim) or (N, embed_dim) and the second containing the corresponding labels/targets for the embedding data of shape (N, target_dim)

  • dev_dataset – an optional development set tuple, with same format as dataset (Default: None)

  • n_iter – number of iterations to run model fitting (Default: 50 iterations)

  • patience – number of iterations without improvement (decrease) in loss to tolerate before early stopping is enacted

Returns:

best accuracy found over fitting run
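
The interaction between n_iter and patience can be sketched with a generic early-stopping loop. This assumes the common reading of patience as the number of consecutive iterations without a decrease in loss that are tolerated before stopping. The function name and the list of pre-computed losses are hypothetical stand-ins for the probe's actual fitting loop.

```python
def fit_with_patience(losses, patience=20):
    """Scan per-iteration losses and stop after `patience` consecutive
    iterations without a new best (lowest) loss.

    Returns (best_loss, index of the last iteration that was run).
    """
    best = float("inf")
    waited = 0
    stop_iter = len(losses) - 1
    for it, loss in enumerate(losses):
        if loss < best:
            best, waited = loss, 0       # improvement: reset the counter
        else:
            waited += 1
            if waited >= patience:
                stop_iter = it
                break
    return best, stop_iter

dev_losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69]
best, stopped_at = fit_with_patience(dev_losses, patience=3)
print(best, stopped_at)
```

With patience=3 the loop stops at index 5 and never reaches the final, lower loss, which is the usual trade-off of a small patience value.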

predict(data, batch_size=None)[source]

Runs this probe’s inference scheme over a pool of data.

Parameters:
  • data – a dataset or design tensor/matrix containing encoding vector sequences; shape (N, T, embed_dim) or (N, embed_dim)

  • batch_size – optional batch-size argument (Default: None, will use training batch size)

Returns:

the output scores/predictions made by this probe
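
Running inference over a pool of data in fixed-size batches can be sketched as below, in NumPy. The helper predict_in_batches and the stand-in toy_process are hypothetical. Treating a 2D input of shape (N, embed_dim) as a sequence of length one is an assumption about how the two accepted shapes relate.

```python
import numpy as np

def predict_in_batches(process_fn, data, batch_size):
    """Apply process_fn to consecutive batches of data and stack the outputs."""
    if data.ndim == 2:                   # (N, embed_dim) -> (N, 1, embed_dim)
        data = data[:, None, :]
    outputs = []
    for start in range(0, data.shape[0], batch_size):
        batch = data[start:start + batch_size]   # last batch may be smaller
        outputs.append(process_fn(batch))
    return np.concatenate(outputs, axis=0)

# Stand-in for a probe's process(): mean-pool over time, keep two features.
def toy_process(batch):
    return batch.mean(axis=1)[:, :2]

data = np.arange(10 * 3 * 4, dtype=float).reshape(10, 3, 4)
scores = predict_in_batches(toy_process, data, batch_size=4)
print(scores.shape)
```

Because each batch is processed independently, the stacked result matches what a single pass over the whole pool would return for a stateless inference function.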

process(embeddings, dkey=None)[source]

Runs the probe’s inference scheme given an input batch of sequences of encodings/embeddings.

Parameters:
  • embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)

  • dkey – Optional JAX noise key

Returns:

probe output scores/probability values

update(embeddings, labels, dkey=None)[source]

Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally assigned labels/target vector values.

Parameters:
  • embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)

  • labels – target values that map to embedding sequence; shape (B, target_value_dim)

  • dkey – Optional JAX noise key

Returns:

probe output scores/probability values

Module contents