Source code for ngclearn.components.neurons.graded.gaussianErrorCell

# %%

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
    """
    A simple (non-spiking) Gaussian error cell - this is a fixed-point
    calculation of a mismatch signal. Specifically, this error cell offers a
    configurable variance and calculates its local free energy (the Gaussian
    log likelihood, without its additive normalization terms).

    | --- Cell Input Compartments: ---
    | mu - predicted value (takes in external signals)
    | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2
    | target - desired/goal value (takes in external signals)
    | modulator - modulation signal (takes in optional external signals)
    | mask - binary/gating mask to apply to error neuron calculations
    | --- Cell Output Compartments: ---
    | L - local loss function embodied by this cell
    | dmu - derivative of L w.r.t. mu
    | dSigma - derivative of L w.r.t. Sigma
    | dtarget - derivative of L w.r.t. target

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        batch_size: batch size dimension of this cell (Default: 1)

        sigma: initial/fixed value for prediction covariance matrix (𝚺) in
            multivariate gaussian distribution; a Python float/int yields a
            `Sigma` compartment of shape (1, 1), any other value is converted
            with `jnp.array` and its shape is used. The value is stored in
            `Sigma` as-is and enters both the loss and the derivatives as the
            variance term, i.e., L = -||target - mu||^2 / (2 * Sigma) and
            dmu = (target - mu) / Sigma (Default: 1)

        shape: optional 3-tuple; if provided, the value compartments are
            created with shape (batch_size, shape[0], shape[1], shape[2])
            instead of the default (batch_size, n_units) (Default: None)
    """

    def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
        super().__init__(name, **kwargs)

        ## Layer Size Setup
        _shape = (batch_size, n_units)  ## default shape is 2D/matrix
        if shape is None:
            shape = (n_units,)  ## we set shape to be equal to n_units if nothing provided
        else:
            _shape = (batch_size, shape[0], shape[1], shape[2])  ## shape is 4D tensor
        sigma_shape = (1, 1)  ## plain Python scalars are held in a (1, 1) array
        if not isinstance(sigma, (float, int)):
            sigma_shape = jnp.array(sigma).shape
        self.sigma_shape = sigma_shape
        self.shape = shape
        self.batch_size = batch_size
        self.n_units = n_units

        ## Convolution shape setup
        self.width = self.height = n_units

        ## Compartment setup
        restVals = jnp.zeros(_shape)
        self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats")  # loss compartment
        self.mu = Compartment(restVals, display_name="Gaussian mean")  # mean/mean name. input wire
        self.dmu = Compartment(restVals)  # derivative mean
        _Sigma = jnp.zeros(sigma_shape)
        self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance")
        self.dSigma = Compartment(_Sigma)
        self.target = Compartment(restVals, display_name="Gaussian data/target variable")  # target. input wire
        self.dtarget = Compartment(restVals)  # derivative target
        self.modulator = Compartment(restVals + 1.0)  # to be set/consumed
        self.mask = Compartment(restVals + 1.0)

    @staticmethod
    def _eval_log_density(target, mu, Sigma):
        """
        Evaluates the (unnormalized) Gaussian log likelihood of `target`.

        Computes ln(p) = -||target - mu||^2 * 1/(2 Sigma), where the squared
        error is summed over every element (batch included).

        Args:
            target: desired/goal value(s)

            mu: predicted value(s)

            Sigma: variance term the summed squared error is divided by

        Returns:
            log density value (carries the shape of `Sigma`), raw delta (target - mu)
        """
        _dmu = target - mu
        ## Scale by 1/(2 Sigma), NOT 1/(2 Sigma^2): `advance_state` emits
        ## dmu = (target - mu) / Sigma, which is only the derivative of this
        ## quantity w.r.t. mu if Sigma enters linearly (Sigma already plays
        ## the role of the variance sigma^2).
        log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
        return log_density, _dmu  ## return density and raw delta

    @compilable
    def advance_state(self, dt): ## compute Gaussian error cell output (fixed-point)
        # NOTE: `dt` is accepted for interface compatibility only; this cell is a
        #       fixed-point calculation and does not integrate over time.
        # Get the variables
        mu = self.mu.get()
        target = self.target.get()
        Sigma = self.Sigma.get()
        modulator = self.modulator.get()
        mask = self.mask.get()

        # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
        # behavior of the local cost functional:
        # FIXME: Currently, below does: L(targ, mu) = -(1/(2*Sigma)) * ||targ - mu||^2_2
        #        but should support full log likelihood of the multivariate Gaussian with covariance of different types
        # TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE
        #       (using integration time constant dt)
        L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma)  ## _dmu => "raw" e (error unit/mis-match)
        dmu = _dmu * (1./ Sigma)  ## obtain precision-scaled e: (target - mu)/Sigma
        dtarget = -dmu  # reverse of e ## -(target - mu)/Sigma
        dSigma = Sigma * 0 + 1.  # no derivative is calculated at this time for Sigma
        dmu = dmu * modulator * mask  ## not sure how mask will apply to a full covariance...
        dtarget = dtarget * modulator * mask
        mask = mask * 0. + 1.  ## "eat" the mask as it should only apply at time t

        # Update compartments
        self.dmu.set(dmu)
        self.dtarget.set(dtarget)
        self.dSigma.set(dSigma)
        self.L.set(jnp.squeeze(L))
        self.mask.set(mask)

    @compilable
    def reset(self): ## reset core components/statistics
        self.batched_reset(batch_size=self.batch_size)  ## arg = batch_size data-member

    @compilable
    def batched_reset(self, batch_size):
        # Restores the compartments to their resting values for the given batch size;
        # the `Sigma` compartment is left untouched.
        _shape = (batch_size, self.shape[0])
        if len(self.shape) > 1:
            _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
        restVals = jnp.zeros(_shape)
        dmu = restVals
        dtarget = restVals
        dSigma = jnp.zeros(self.sigma_shape)
        target = restVals
        mu = restVals
        modulator = mu + 1.
        L = 0.
        mask = jnp.ones(_shape)

        self.dmu.set(dmu)
        self.dtarget.set(dtarget)
        self.dSigma.set(dSigma)
        # `target` and `mu` are only cleared when their `targeted` flag is not set
        if not self.target.targeted:
            self.target.set(target)
        if not self.mu.targeted:
            self.mu.set(mu)
        self.modulator.set(modulator)
        self.L.set(L)
        self.mask.set(mask)

    @classmethod
    def help(cls): ## component help function
        """
        Builds a dictionary describing this component.

        Returns:
            dict with the cell-type description (keyed by the class name), its
            compartments, its dynamics, and its hyperparameters
        """
        properties = {
            "cell_type": "GaussianErrorCell - computes mismatch/error signals at "
                         "each time step t (between a `target` and a prediction `mu`)"
        }
        compartment_props = {
            "inputs":
                {"mu": "External input prediction value(s)",
                 "Sigma": "External variance/covariance prediction value(s)",
                 "target": "External input target signal value(s)",
                 "modulator": "External input modulatory/scaling signal(s)",
                 "mask": "External binary/gating mask to apply to signals"},
            "outputs":
                {"L": "Local loss / free-energy value embodied by this error-cell",
                 "dmu": "first derivative of loss w.r.t. prediction value(s)",
                 "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)",
                 "dtarget": "first derivative of loss w.r.t. target value(s)"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "batch_size": "Batch size dimension of this component",
            "sigma": "External input variance value (currently fixed and not learnable)"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "Gaussian(x=target; mu, sigma)",
                "hyperparameters": hyperparams}
        return info

if __name__ == '__main__':
    # Smoke test: build a single 9-unit error cell inside a simulation
    # context and print it.
    from ngcsimlib.context import Context

    with Context("Bar") as ctx:
        cell = GaussianErrorCell("X", 9)
        print(cell)