Source code for ngclearn.components.neurons.graded.bernoulliErrorCell

# %%

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
from ngclearn.utils.model_utils import sigmoid, d_sigmoid

from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

[docs] class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ A simple (non-spiking) Bernoulli error cell - this is a fixed-point solution of a mismatch signal. Specifically, this cell operates as a factorized multivariate Bernoulli distribution. | --- Cell Input Compartments: --- | p - predicted probability (or logits) of positive trial (takes in external signals) | target - desired/goal value (takes in external signals) | modulator - modulation signal (takes in optional external signals) | mask - binary/gating mask to apply to error neuron calculations | --- Cell Output Compartments: --- | L - local loss function embodied by this cell | dp - derivative of L w.r.t. p (or logits, if p = sigmoid(logits)) | dtarget - derivative of L w.r.t. target Args: name: the string name of this cell n_units: number of cellular entities (neural population size) batch_size: batch size dimension of this cell (Default: 1) input_logits: if True, treats compartment `p` as logits and will apply a sigmoidal link, i.e., _p = sigmoid(p), to obtain the param p for Bern(X=1; p) """ def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None, **kwargs): super().__init__(name, **kwargs) ## Layer Size Setup _shape = (batch_size, n_units) ## default shape is 2D/matrix if shape is None: shape = (n_units,) ## we set shape to be equal to n_units if nothing provided else: _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor self.shape = shape self.n_units = n_units self.batch_size = batch_size self.input_logits = input_logits ## Convolution shape setup self.width = self.height = n_units ## Compartment setup restVals = jnp.zeros(_shape) self.L = Compartment(0., display_name="Bernoulli Log likelihood", units="nats") # loss compartment self.p = Compartment(restVals, display_name="Bernoulli param (prob or logit) for B(X=1; p)") # pos trial prob name. input wire self.dp = Compartment(restVals) # derivative of positive trial prob self.target = Compartment(restVals, display_name="Bernoulli data/target variable") # target. input wire self.dtarget = Compartment(restVals) # derivative target self.modulator = Compartment(restVals + 1.0) # to be set/consumed self.mask = Compartment(restVals + 1.0) # @transition(output_compartments=["dp", "dtarget", "L", "mask"])
[docs] @compilable def advance_state(self, dt): ## compute Bernoulli error cell output # Get the variables p = self.p.get() target = self.target.get() modulator = self.modulator.get() mask = self.mask.get() # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit # behavior of the local cost functional eps = 0.0001 _p = p if self.input_logits: ## convert from "logits" to probs via sigmoidal link function _p = sigmoid(p) _p = jnp.clip(_p, eps, 1. - eps) ## post-process to prevent div by 0 x = target #sum_x = jnp.sum(x) ## Sum^N_{n=1} x_n (n is n-th datapoint) #sum_1mx = jnp.sum(1. - x) ## Sum^N_{n=1} (1 - x_n) one_min_p = 1. - _p one_min_x = 1. - x log_p = jnp.log(_p) ## ln(p) log_one_min_p = jnp.log(one_min_p) ## ln(1 - p) L = jnp.sum(log_p * x + log_one_min_p * one_min_x) ## Bern LL if self.input_logits: dL_dp = x - _p ## d(Bern LL)/dp where _p = sigmoid(p) else: dL_dp = x/(_p) - one_min_x/one_min_p ## d(Bern LL)/dp dL_dp = dL_dp * d_sigmoid(p) dL_dx = (log_p - log_one_min_p) ## d(Bern LL)/dx dp = dL_dp dp = dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli? dtarget = dL_dx * modulator * mask mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t # Set state # dp, dtarget, jnp.squeeze(L), mask self.dp.set(dp) self.dtarget.set(dtarget) self.L.set(jnp.squeeze(L)) self.mask.set(mask)
[docs] @compilable def reset(self): ## reset core components/statistics self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
[docs] @compilable def batched_reset(self, batch_size): _shape = (batch_size, self.shape[0]) if len(self.shape) > 1: _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) ## "rest"/reset values dp = restVals dtarget = restVals target = restVals p = restVals modulator = restVals + 1. ## reset modulator signal L = 0. #jnp.zeros((1, 1)) ## rest loss mask = jnp.ones(_shape) ## reset mask # Set compartment self.dp.set(dp) self.dtarget.set(dtarget) self.target.set(target) self.p.set(p) self.modulator.set(modulator) self.L.set(L) self.mask.set(mask)
[docs] @classmethod def help(cls): ## component help function properties = { "cell_type": "GaussianErrorcell - computes mismatch/error signals at " "each time step t (between a `target` and a prediction `mu`)" } compartment_props = { "inputs": {"p": "External input positive probability value(s)", "target": "External input target signal value(s)", "modulator": "External input modulatory/scaling signal(s)", "mask": "External binary/gating mask to apply to signals"}, "outputs": {"L": "Local loss value computed/embodied by this error-cell", "dp": "first derivative of loss w.r.t. positive probability value(s)", "dtarget": "first derivative of loss w.r.t. target value(s)"}, } hyperparams = { "n_units": "Number of neuronal cells to model in this layer", "batch_size": "Batch size dimension of this component", "sigma": "External input variance value (currently fixed and not learnable)" } info = {cls.__name__: properties, "compartments": compartment_props, "dynamics": "Bernoulli(x=target; p) where target is binary variable", "hyperparameters": hyperparams} return info
if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: X = BernoulliErrorCell("X", 9) print(X)