Source code for ngclearn.modules.regression.elastic_net

import numpy as np
from ngclearn.utils.distribution_generator import DistributionGenerator as dist
from ngclearn import numpy as jnp

from jax import numpy as jnp, random, jit
from ngclearn import Context, MethodProcess
from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
from ngcsimlib.global_state import stateManager

[docs] class Iterative_ElasticNet(): """ A neural circuit implementation of the iterative Elastic Net (L1 and L2) algorithm using a Hebbian learning update rule. The circuit implements sparse regression through Hebbian synapses with Elastic Net regularization. The specific differential equation that characterizes this model is dW_reg (for adjusting W, given dW (the gradient of loss/energy function), it adds lmbda * dW_reg to the dW) | dW_reg = (jnp.sign(W) * l1_ratio) + (W * (1-l1_ratio)/2) | dW/dt = dW + lmbda * dW_reg | --- Circuit Components: --- | W - HebbianSynapse for learning regularized dictionary weights | err - GaussianErrorCell for computing prediction errors | --- Component Compartments --- | W.inputs - input features (takes in external signals) | W.pre - pre-synaptic activity for Hebbian learning | W.post - post-synaptic error signals | W.weights - learned dictionary coefficients | err.mu - predicted outputs | err.target - target signals (target vector) | err.dmu - error gradients | err.L - loss/energy values Args: key: JAX PRNG key for random number generation name: string name for this solver sys_dim: dimensionality of the system/target space dict_dim: dimensionality of the dictionary/feature space/the number of predictors batch_size: number of samples to process in parallel weight_fill: initial constant value to fill weight matrix with (Default: 0.05) lr: learning rate for synaptic weight updates (Default: 0.01) lmbda: elastic net regularization lambda parameter (Default: 0.0001) optim_type: optimization type for updating weights; supported values are "sgd" and "adam" (Default: "adam") threshold: minimum absolute coefficient value - values below this are set to zero during thresholding (Default: 0.001) epochs: number of training epochs (Default: 100) """ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, lr=0.01, lmbda = 0.0001, l1_ratio=0.5, optim_type="adam", threshold=0.05, epochs=100): key, *subkeys = random.split(key, 10) ## synaptic plasticity properties and characteristics self.T = 100 self.dt = 1 self.epochs = epochs self.weight_fill = weight_fill self.threshold = threshold self.name = name self.lr = lr feature_dim = dict_dim with Context(self.name) as self.circuit: self.W = HebbianSynapse( "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, weight_init=dist.constant(value=weight_fill), prior=('elastic_net', (lmbda, l1_ratio)), w_bound=0., optim_type=optim_type, key=subkeys[0] ) self.err = GaussianErrorCell("err", n_units=sys_dim) # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.batch_size = batch_size self.err.batch_size = batch_size # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.outputs >> self.err.mu self.err.dmu >> self.W.post # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ advance = (MethodProcess(name="advance_state") >> self.W.advance_state >> self.err.advance_state) self.advance = advance evolve = (MethodProcess(name="evolve") >> self.W.evolve) self.evolve = evolve reset = (MethodProcess(name="reset") >> self.err.reset >> self.W.reset) self.reset = reset
[docs] def batch_set(self, batch_size): self.W.batch_size = batch_size self.err.batch_size = batch_size
[docs] def clamp(self, y_scaled, X): self.W.inputs.set(X) self.W.pre.set(X) self.err.target.set(y_scaled)
[docs] def thresholding(self, scale=1.): coef_old = self.coef_ new_coeff = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.) self.coef_ = new_coeff * scale self.W.weights.set(new_coeff) return self.coef_, coef_old
[docs] def fit(self, y, X): self.reset.run() self.clamp(y_scaled=y, X=X) for i in range(self.epochs): inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt)) stateManager.state, outputs = self.advance.scan(inputs) self.evolve.run(t=self.T, dt=self.dt) self.coef_ = np.array(self.W.weights.get()) return self.coef_, self.err.mu.get(), self.err.L.get()