ngclearn.components.neurons.graded package

Submodules

ngclearn.components.neurons.graded.bernoulliErrorCell module

class ngclearn.components.neurons.graded.bernoulliErrorCell.BernoulliErrorCell(*args, **kwargs)[source]

Bases: JaxComponent

A simple (non-spiking) Bernoulli error cell - this is a fixed-point solution of a mismatch signal. Specifically, this cell operates as a factorized multivariate Bernoulli distribution.

— Cell Input Compartments: —
p - predicted probability (or logits) of positive trial (takes in external signals)
target - desired/goal value (takes in external signals)
modulator - modulation signal (takes in optional external signals)
mask - binary/gating mask to apply to error neuron calculations
— Cell Output Compartments: —
L - local loss function embodied by this cell
dp - derivative of L w.r.t. p (or logits, if p = sigmoid(logits))
dtarget - derivative of L w.r.t. target
Parameters:
  • name – the string name of this cell

  • n_units – number of cellular entities (neural population size)

  • batch_size – batch size dimension of this cell (Default: 1)

  • input_logits – if True, treats compartment p as logits and will apply a sigmoidal link, i.e., _p = sigmoid(p), to obtain the param p for Bern(X=1; p)

advance_state(dt)[source]
batched_reset(batch_size)[source]
classmethod help()[source]
reset()[source]
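
The compartments listed above can be illustrated with a small NumPy sketch. This is not the library's implementation: the helper name bernoulli_error, the clipping constant eps, and the choice of the log-likelihood (rather than its negative) as L are all assumptions made for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bernoulli_error(p, target, input_logits=False, eps=1e-7):
    """Hypothetical helper: factorized Bernoulli log-likelihood and its gradients."""
    # If p holds logits, apply the sigmoidal link to obtain probabilities.
    prob = sigmoid(p) if input_logits else p
    prob = np.clip(prob, eps, 1.0 - eps)  # keep the logarithms finite
    # Assumed sign convention: L is the log-likelihood summed over all units.
    L = np.sum(target * np.log(prob) + (1.0 - target) * np.log(1.0 - prob))
    if input_logits:
        dp = target - prob  # derivative of L w.r.t. the logits
    else:
        dp = (target - prob) / (prob * (1.0 - prob))  # derivative of L w.r.t. p
    dtarget = np.log(prob) - np.log(1.0 - prob)  # derivative of L w.r.t. target
    return L, dp, dtarget

logits = np.array([[2.0, -1.0, 0.5]])  # shape (batch_size, n_units)
target = np.array([[1.0, 0.0, 1.0]])
L, dp, dtarget = bernoulli_error(logits, target, input_logits=True)
```

With input_logits=True the gradient with respect to the logits reduces to target - sigmoid(logits), which is why the logit form is often the numerically safer choice.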

ngclearn.components.neurons.graded.gaussianErrorCell module

class ngclearn.components.neurons.graded.gaussianErrorCell.GaussianErrorCell(*args, **kwargs)[source]

Bases: JaxComponent

A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically, this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood).

— Cell Input Compartments: —
mu - predicted value (takes in external signals)
Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it’s sigma^2
target - desired/goal value (takes in external signals)
modulator - modulation signal (takes in optional external signals)
mask - binary/gating mask to apply to error neuron calculations
— Cell Output Compartments: —
L - local loss function embodied by this cell
dmu - derivative of L w.r.t. mu
dSigma - derivative of L w.r.t. Sigma
dtarget - derivative of L w.r.t. target
Parameters:
  • name – the string name of this cell

  • n_units – number of cellular entities (neural population size)

  • batch_size – batch size dimension of this cell (Default: 1)

  • sigma – initial/fixed value for the prediction covariance matrix (𝚺) of the multivariate Gaussian distribution; note that if the compartment Sigma is never used, this cell assumes that the covariance collapses to a constant/fixed sigma^2, i.e., Sigma = sigma^2, where sigma is a scalar standard deviation argument (Default: 1)

advance_state(dt)[source]
batched_reset(batch_size)[source]
classmethod help()[source]
reset()[source]
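
As a rough illustration of the quantities named above, the following NumPy sketch evaluates a Gaussian log-likelihood with a fixed scalar variance and its derivatives. The helper name gaussian_error, the inclusion of the log-normalizer, and the sign convention are assumptions, not the library's code.

```python
import numpy as np

def gaussian_error(mu, target, variance=1.0):
    """Hypothetical helper: Gaussian log-likelihood with scalar variance (sigma^2)."""
    e = target - mu  # mismatch/error signal
    n = e.size
    # Assumed form: full log-likelihood, including the normalizing constant.
    L = -0.5 * np.sum(e ** 2) / variance - 0.5 * n * np.log(2.0 * np.pi * variance)
    dmu = e / variance   # derivative of L w.r.t. mu
    dtarget = -dmu       # derivative of L w.r.t. target
    # Derivative of L w.r.t. the scalar variance.
    dSigma = 0.5 * np.sum(e ** 2) / variance ** 2 - 0.5 * n / variance
    return L, dmu, dSigma, dtarget

mu = np.array([[0.5, -0.2]])   # shape (batch_size, n_units)
target = np.array([[1.0, 0.0]])
L, dmu, dSigma, dtarget = gaussian_error(mu, target, variance=4.0)
```

Here dmu is the variance-weighted error (target - mu) / sigma^2, the signal that predictive coding circuits typically pass back to the predicting layer.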

ngclearn.components.neurons.graded.laplacianErrorCell module

class ngclearn.components.neurons.graded.laplacianErrorCell.LaplacianErrorCell(*args, **kwargs)[source]

Bases: JaxComponent

A simple (non-spiking) Laplacian error cell - this is a fixed-point solution of a mismatch/error signal.

— Cell Input Compartments: —
shift - predicted shift value (takes in external signals)
Scale - predicted scale (takes in external signals)
target - desired/goal value (takes in external signals)
modulator - modulation signal (takes in optional external signals)
mask - binary/gating mask to apply to error neuron calculations
— Cell Output Compartments: —
L - local loss function embodied by this cell
dshift - derivative of L w.r.t. shift
dScale - derivative of L w.r.t. Scale
dtarget - derivative of L w.r.t. target
Parameters:
  • name – the string name of this cell

  • n_units – number of cellular entities (neural population size)

  • batch_size – batch size dimension of this cell (Default: 1)

  • scale – initial/fixed value for the prediction scale matrix of the multivariate Laplacian distribution; note that if the compartment Scale is never used, this cell assumes that the scale collapses to a constant/fixed scale

advance_state(dt)[source]
batched_reset(batch_size)[source]
classmethod help()[source]
reset()[source]
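
The Laplacian counterpart can be sketched the same way. The helper name laplacian_error, the scalar scale, and the use of the full log-likelihood as L are assumptions for illustration only.

```python
import numpy as np

def laplacian_error(shift, target, scale=1.0):
    """Hypothetical helper: Laplacian log-likelihood with a scalar scale."""
    e = target - shift  # mismatch/error signal
    n = e.size
    # Assumed form: full log-likelihood, including the normalizing constant.
    L = -np.sum(np.abs(e)) / scale - n * np.log(2.0 * scale)
    dshift = np.sign(e) / scale  # derivative of L w.r.t. shift
    dtarget = -dshift            # derivative of L w.r.t. target
    dScale = np.sum(np.abs(e)) / scale ** 2 - n / scale  # derivative w.r.t. scale
    return L, dshift, dScale, dtarget

shift = np.array([[0.5, 0.3]])  # shape (batch_size, n_units)
target = np.array([[1.0, 0.0]])
L, dshift, dScale, dtarget = laplacian_error(shift, target, scale=2.0)
```

Unlike the Gaussian case, the gradient with respect to shift depends only on the sign of the error, so large outliers do not produce proportionally large error signals.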

ngclearn.components.neurons.graded.leakyNoiseCell module

class ngclearn.components.neurons.graded.leakyNoiseCell.LeakyNoiseCell(*args, **kwargs)[source]

Bases: JaxComponent

A non-spiking cell driven by the gradient dynamics entailed by a continuous-time noisy, leaky recurrent state.

Reference: https://pmc.ncbi.nlm.nih.gov/articles/PMC4771709/

The differential equation that characterizes this cell (for adjusting x) is:

tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha (sigma_pre)^2) * eps; and,
r = f(x) + (eps * sigma_post).
where j_in is the set of incoming input signals
and j_rec is the set of recurrent input signals
and eps is a sample of unit Gaussian noise, i.e., eps ~ N(0, 1)
and f(x) is the rectification function
and sigma_pre is the pre-rectification noise applied to membrane x
and sigma_post is the post-rectification noise applied to rates f(x)
— Cell Input Compartments: —
j_input - input (bottom-up) electrical/stimulus current (takes in external signals)
j_recurrent - recurrent electrical/stimulus pressure
— Cell State Compartments —
x - noisy rate activity / current value of state
— Cell Output Compartments: —
r - post-rectified activity, e.g., fx(x) = relu(x)
r_prime - post-rectified temporal derivative, e.g., dfx(x) = d_relu(x)
Parameters:
  • name – the string name of this cell

  • n_units – number of cellular entities (neural population size)

  • tau_x – state membrane time constant (milliseconds)

  • act_fx – rectification function (Default: “relu”)

  • output_scale – factor to multiply output of nonlinearity of this cell by (Default: 1.)

  • integration_type

    type of integration to use for this cell’s dynamics; current supported forms include “euler” (Euler/RK-1 integration) and “midpoint” or “rk2” (midpoint method/RK-2 integration) (Default: “euler”)

    Note:

    setting the integration type to the midpoint method will increase the accuracy of the estimate of the cell’s evolution at an increase in computational cost (and simulation time)

  • sigma_pre – pre-rectification noise scaling factor / standard deviation (Default: 0.1)

  • sigma_post – post-rectification noise scaling factor / standard deviation (Default: 0.)

  • leak_scale – degree to which membrane leak should be scaled (Default: 1)

advance_state(t, dt)[source]
batched_reset(batch_size)[source]
classmethod help()[source]
reset()[source]
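
One way to read the dynamics above is as a single noisy Euler-style update. The sketch below is not the library's integrator. The text does not define alpha, so this sketch assumes alpha = dt / tau_x and adds the pre-rectification noise term directly to the state update; it also assumes a relu rectifier and ignores output_scale and leak_scale. The function name leaky_noise_step is hypothetical.

```python
import numpy as np

def leaky_noise_step(x, j_in, j_rec, dt, tau_x,
                     sigma_pre=0.1, sigma_post=0.0, rng=None):
    """Hypothetical helper: one noisy update of the leaky state x."""
    rng = np.random.default_rng() if rng is None else rng
    alpha = dt / tau_x  # ASSUMPTION: alpha is not defined in the text above
    eps_pre = rng.standard_normal(x.shape)   # eps ~ N(0, 1), pre-rectification
    eps_post = rng.standard_normal(x.shape)  # eps ~ N(0, 1), post-rectification
    # Leaky drift toward the summed inputs, plus pre-rectification noise.
    x_new = x + alpha * (-x + j_rec + j_in) \
        + np.sqrt(2.0 * alpha * sigma_pre ** 2) * eps_pre
    r = np.maximum(x_new, 0.0) + sigma_post * eps_post  # r = f(x) + eps * sigma_post
    r_prime = (x_new > 0.0).astype(x_new.dtype)         # derivative of relu at x
    return x_new, r, r_prime

x0 = np.zeros((1, 3))     # shape (batch_size, n_units)
j_in = np.ones((1, 3))
j_rec = np.zeros((1, 3))

# Noise-free step: the state moves a fraction dt / tau_x toward the input.
x1, r1, r1_prime = leaky_noise_step(x0, j_in, j_rec, dt=1.0, tau_x=10.0,
                                    sigma_pre=0.0, sigma_post=0.0)

# Noisy step with a seeded generator for reproducibility.
x_noisy, r_noisy, _ = leaky_noise_step(x0, j_in, j_rec, dt=1.0, tau_x=10.0,
                                       sigma_pre=0.1, sigma_post=0.05,
                                       rng=np.random.default_rng(0))
```

Setting both noise scales to zero recovers a plain leaky integrator, which is a convenient way to check the deterministic part of the dynamics before adding noise.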

ngclearn.components.neurons.graded.rateCell module

class ngclearn.components.neurons.graded.rateCell.RateCell(*args, **kwargs)[source]

Bases: JaxComponent

A non-spiking cell driven by the gradient dynamics of neural generative coding-driven predictive processing.

The differential equation that characterizes this cell (for adjusting z, given current j, over time) is:

tau_m * dz/dt = lambda * prior(z) + (j + j_td)
where j is the set of general incoming input signals (e.g., message-passed signals)
and j_td is taken to be the set of top-down pressure signals
— Cell Input Compartments: —
j - input pressure (takes in external signals)
j_td - input/top-down pressure input (takes in external signals)
— Cell State Compartments —
z - rate activity
— Cell Output Compartments: —
zF - post-activation function activity, i.e., fx(z)
Parameters:
  • name – the string name of this cell

  • n_units – number of cellular entities (neural population size)

  • tau_m – membrane/state time constant (milliseconds)

  • prior

    a kernel for specifying the type of centered scale-shift distribution to impose over neuronal dynamics, applied to each neuron or dimension within this component (Default: (“gaussian”, 0)); this is a tuple whose first element is the string name of the distribution to use and whose second element is a leak rate scalar that controls the influence/weighting that this distribution has on the dynamics; for example, (“laplacian”, 0.001) means that a centered Laplacian distribution scaled by 0.001 will be injected into this cell’s dynamics ODE at each step of simulated time

    Note:

    supported scale-shift distributions include “laplacian”, “cauchy”, “exp”, and “gaussian”

  • act_fx – string name of activation function/nonlinearity to use

  • output_scale – factor to multiply output of nonlinearity of this cell by (Default: 1.)

  • integration_type

    type of integration to use for this cell’s dynamics; current supported forms include “euler” (Euler/RK-1 integration) and “midpoint” or “rk2” (midpoint method/RK-2 integration) (Default: “euler”)

    Note:

    setting the integration type to the midpoint method will increase the accuracy of the estimate of the cell’s evolution at an increase in computational cost (and simulation time)

  • resist_scale – a scaling factor applied to incoming pressure j (Default: 1)

advance_state(dt)[source]
batched_reset(batch_size)[source]
classmethod help()[source]
reset()[source]
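
The ODE and the two integration types described above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the library's code: the function names are hypothetical, and the prior terms are assumed to be -z for the "gaussian" kernel and -sign(z) for the "laplacian" kernel (the other kernels are omitted).

```python
import numpy as np

def prior_drift(z, prior_type):
    """ASSUMED prior terms: gradient of a centered log-density w.r.t. z."""
    if prior_type == "gaussian":
        return -z
    if prior_type == "laplacian":
        return -np.sign(z)
    raise ValueError(f"prior not covered by this sketch: {prior_type}")

def dz_dt(z, j, j_td, tau_m, prior=("gaussian", 0.0), resist_scale=1.0):
    """Right-hand side: (lambda * prior(z) + (j + j_td)) / tau_m."""
    prior_type, leak = prior
    return (leak * prior_drift(z, prior_type) + resist_scale * j + j_td) / tau_m

def euler_step(z, dt, **kw):
    # Euler / RK-1: one derivative evaluation per step.
    return z + dt * dz_dt(z, **kw)

def midpoint_step(z, dt, **kw):
    # Midpoint / RK-2: evaluate the derivative at a half step, then step fully.
    z_mid = z + 0.5 * dt * dz_dt(z, **kw)
    return z + dt * dz_dt(z_mid, **kw)

# Pure leak (no input): the exact solution is z0 * exp(-dt / tau_m).
z0 = np.array([[1.0]])
kw = dict(j=0.0, j_td=0.0, tau_m=10.0, prior=("gaussian", 1.0))
dt = 5.0
z_exact = z0 * np.exp(-dt / 10.0)
z_euler = euler_step(z0, dt, **kw)
z_mid = midpoint_step(z0, dt, **kw)
err_euler = float(np.abs(z_euler - z_exact).max())
err_mid = float(np.abs(z_mid - z_exact).max())
```

On this deliberately coarse step the midpoint estimate lands closer to the exact decay than the Euler estimate, at the cost of a second derivative evaluation per step.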

ngclearn.components.neurons.graded.rewardErrorCell module

class ngclearn.components.neurons.graded.rewardErrorCell.RewardErrorCell(*args, **kwargs)[source]

Bases: JaxComponent

A reward prediction error (RPE) cell.

— Cell Input Compartments: —
reward - current reward signal at time t
accum_reward - current accumulated episodic reward signal
— Cell Output Compartments: —
mu - current moving average prediction of reward at time t
rpe - current reward prediction error (RPE) signal
accum_reward - current accumulated episodic reward signal (IF online predictor not used)
Parameters:
  • name – the string name of this cell

  • n_units – number of cellular entities (neural population size)

  • alpha – decay factor to apply to (exponential) moving average prediction

  • ema_window_len – exponential moving average window length, used only in the evolve step for updating episodic reward signals (Default: 10)

  • use_online_predictor – use online prediction of the reward signal (Default: True); if set to False, reward prediction occurs only upon a call to this cell’s evolve function

advance_state(dt)[source]
evolve(dt)[source]
classmethod help()[source]
reset()[source]
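
A minimal sketch of an online reward prediction error update is given below. It assumes that the prediction error is computed against the current moving average before that average is updated, and that alpha weights the newest reward; the function name reward_error_step and this exact ordering are assumptions, not confirmed library behavior.

```python
import numpy as np

def reward_error_step(reward, mu, accum_reward, alpha=0.1):
    """Hypothetical helper: one online RPE update with an exponential moving average."""
    rpe = reward - mu                                # error against current prediction
    mu_new = (1.0 - alpha) * mu + alpha * reward     # ASSUMED EMA update rule
    accum_new = accum_reward + reward                # running episodic reward
    return mu_new, rpe, accum_new

mu = np.zeros((1, 1))
accum = np.zeros((1, 1))
rpes = []
for reward_value in [1.0, 1.0, 1.0]:
    reward = np.full((1, 1), reward_value)
    mu, rpe, accum = reward_error_step(reward, mu, accum, alpha=0.1)
    rpes.append(float(rpe[0, 0]))
```

With a constant reward the prediction mu climbs toward the reward and the RPE shrinks geometrically, which is the behavior one expects from a moving-average predictor.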

Module contents