Source code for ngclearn.utils.analysis.linear_probe

import jax
import numpy as np
from ngclearn.utils.analysis.probe import Probe
from ngclearn.utils.model_utils import drop_out, softmax, layer_normalize
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind
from ngclearn.utils.distribution_generator import DistributionGenerator
from ngclearn.utils.optim import adam, sgd

@bind(jax.jit, static_argnums=[2, 3])
def run_linear_probe(params, x, use_softmax=False, use_LN=False):
    """
    Computes the output of the linear probe for a batch of inputs.

    Args:
        params: list of probe parameters, unpacked as `[Wln_mu, Wln_scale, W, b]`

        x: input batch; it is right-multiplied by `W` and shifted by `b`

        use_softmax: if True, `softmax` is applied to the linear output
            (static argument under jit)

        use_LN: if True, the input is first passed through
            `layer_normalize(x, Wln_mu, Wln_scale)` (static argument under jit)

    Returns:
        the probe's output for `x` (post-softmax if `use_softmax` is True)
    """
    ln_shift, ln_scale, weights, bias = params
    ## optionally normalize the input vector given to the probe predictor
    inputs = layer_normalize(x, ln_shift, ln_scale) if use_LN else x
    preds = jnp.matmul(inputs, weights) + bias
    if use_softmax:
        return softmax(preds)
    return preds
@bind(jax.jit, static_argnums=[3, 4])
def eval_linear_probe(params, x, y, use_softmax=True, use_LN=False):
    """
    Runs the linear probe on `x` and scores its predictions against targets `y`.

    Args:
        params: list of probe parameters `[Wln_mu, Wln_scale, W, b]`, handed
            unchanged to `run_linear_probe`

        x: input batch given to the probe; its leading dimension is used as
            the normalizer of the squared-error loss

        y: target batch; must broadcast against the probe's output

        use_softmax: if True, the probe output is softmax-ed and scored with a
            categorical (Multinoulli) negative log likelihood; otherwise a
            squared-error loss is used (static argument under jit)

        use_LN: forwarded to `run_linear_probe` (static argument under jit)

    Returns:
        a tuple `(L, y_mu)` -- the scalar loss and the probe's predictions
    """
    y_mu = run_linear_probe(params, x, use_softmax=use_softmax, use_LN=use_LN)
    if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
        ## floor probabilities at the smallest normal float so that an
        ## underflowed softmax output (exactly 0) cannot yield log(0) = -inf,
        ## which would turn the loss into inf / NaN (-inf * 0) and poison grads
        p_floor = jnp.finfo(y_mu.dtype).tiny
        log_p = jnp.log(jnp.maximum(y_mu, p_floor))
        L = -jnp.mean(jnp.sum(log_p * y, axis=1, keepdims=True))
    else: ## MSE for real-valued outputs
        e = y_mu - y
        L = jnp.sum(jnp.square(e)) * 1./x.shape[0]
    return L, y_mu
#return y_mu, L, e

## NOTE: everything below is commented-out (disabled) code

# @bind(jax.jit, static_argnums=[6, 7])
# def calc_linear_probe_grad(x, y, params, eta, decay=0., l1_decay=0., use_softmax=False, use_LN=False):
#     y_mu, L, e = eval_linear_probe(params, x, y, use_softmax=use_softmax, use_LN=use_LN)
#     Wln_mu, Wln_scale, W, b = params
#     dW = jnp.matmul(x.T, e) + W * decay/eta + jnp.abs(W) * 0.5 * l1_decay/eta
#     db = jnp.sum(e, axis=0, keepdims=True)
#     dW = dW * (1. / x.shape[0])
#     db = db * (1. / x.shape[0])
#     return y_mu, L, [dW, db]

# @jit
# def update_linear_probe(x, y, params, eta, decay=0., l1_decay=0., use_softmax=False):
#     y_mu, L, e = run_linear_probe(x, params, use_softmax=use_softmax)
#     W, b = params
#     dW = jnp.matmul(x.T, e)
#     db = jnp.sum(e, axis=0, keepdims=True)
#     W = W - dW * eta/x.shape[0] - W * decay/x.shape[0] - jnp.abs(W) * 0.5 * l1_decay/x.shape[0]
#     b = b - db * eta/x.shape[0]
#     return y_mu, L, [W, b]
class LinearProbe(Probe):
    """
    This implements a linear probe, which is useful for evaluating the quality
    of encodings/embeddings in light of some supervisory downstream data (e.g.,
    label one-hot encodings or real-valued vector regression targets).

    Note that L2/L1 decay coefficients (`l2_decay`, `l1_decay`) are stored on
    the probe, but the gradient-based update implemented in this class does not
    apply them -- TODO confirm whether they are consumed elsewhere (e.g., by
    the parent `Probe`).

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        batch_size: size of batches to process per internal call to update (or process)

        use_LN: should layer normalization be used on incoming input vectors given to this probe?

        use_softmax: should a softmax be applied to output of probe or not?
    """

    def __init__(
            self, dkey, source_seq_length, input_dim, out_dim, batch_size=1, use_LN=False, use_softmax=False, **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        self.dkey, *subkeys = random.split(self.dkey, 3)
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.use_softmax = use_softmax
        self.use_LN = use_LN
        ## decay coefficients (not applied by this class's `update`)
        self.l2_decay = 0.0001
        self.l1_decay = 0.000025
        # eta = 0.05 for SGD, batch_size=2000

        ## set up classifier (operates on flattened sequence inputs)
        flat_input_dim = input_dim * source_seq_length
        weight_init = DistributionGenerator.fan_in_gaussian()
        Wln_mu = jnp.zeros((1, flat_input_dim))
        Wln_scale = jnp.ones((1, flat_input_dim))
        W = weight_init((flat_input_dim, out_dim), subkeys[0])
        b = jnp.zeros((1, out_dim))
        self.probe_params = [Wln_mu, Wln_scale, W, b]

        ## set up gradient calculator; the loss is differentiated w.r.t. the
        ## parameter list (arg 0), predictions ride along as auxiliary output
        self.grad_fx = jax.value_and_grad(eval_linear_probe, argnums=0, has_aux=True)
        ## set up update rule/optimizer
        self.optim_params = adam.adam_init(self.probe_params)
        self.eta = 0.001

    @staticmethod
    def _flatten_embeddings(embeddings):
        """Collapses every axis after the first so the probe sees a 2D batch."""
        if len(embeddings.shape) <= 2:
            return embeddings
        ## shapes are static Python ints, so the product is computed host-side;
        ## an explicit size (rather than -1) keeps empty batches reshapeable
        flat_dim = int(np.prod(embeddings.shape[1:]))
        return jnp.reshape(embeddings, (embeddings.shape[0], flat_dim))

    def process(self, embeddings, dkey=None):
        """
        Runs the probe on a batch of embeddings.

        Args:
            embeddings: input batch; anything with more than 2 axes is
                flattened to `(embeddings.shape[0], -1)` first

            dkey: unused by this probe (accepted for interface compatibility)

        Returns:
            the probe's outputs for the (flattened) embeddings
        """
        _embeddings = self._flatten_embeddings(embeddings)
        outs = run_linear_probe(
            self.probe_params, _embeddings, use_softmax=self.use_softmax, use_LN=self.use_LN
        )
        return outs

    def update(self, embeddings, labels, dkey=None):
        """
        Performs one Adam step on the probe parameters for a batch.

        Args:
            embeddings: input batch; anything with more than 2 axes is
                flattened to `(embeddings.shape[0], -1)` first

            labels: target batch scored against the probe's predictions

            dkey: unused by this probe (accepted for interface compatibility)

        Returns:
            a tuple `(loss, predictions)` computed before the parameter step
        """
        _embeddings = self._flatten_embeddings(embeddings)
        ## compute adjustments to probe parameters
        outputs, grads = self.grad_fx(
            self.probe_params, _embeddings, labels, use_softmax=self.use_softmax, use_LN=self.use_LN
        )
        loss, predictions = outputs
        ## adjust parameters of probe
        self.optim_params, self.probe_params = adam.adam_step(
            self.optim_params, self.probe_params, grads, eta=self.eta
        )
        return loss, predictions