ngclearn.components.neurons.graded package
Submodules
ngclearn.components.neurons.graded.bernoulliErrorCell module
- class ngclearn.components.neurons.graded.bernoulliErrorCell.BernoulliErrorCell(*args, **kwargs)
Bases: JaxComponent
A simple (non-spiking) Bernoulli error cell; this is a fixed-point solution of a mismatch signal. Specifically, this cell operates as a factorized multivariate Bernoulli distribution.
Cell Input Compartments:
p - predicted probability (or logits) of a positive trial (takes in external signals)
target - desired/goal value (takes in external signals)
modulator - modulation signal (takes in optional external signals)
mask - binary/gating mask to apply to error neuron calculations
Cell Output Compartments:
L - local loss function embodied by this cell
dp - derivative of L w.r.t. p (or w.r.t. the logits, if p = sigmoid(logits))
dtarget - derivative of L w.r.t. target
- Parameters:
name – the string name of this cell
n_units – number of cellular entities (neural population size)
batch_size – batch size dimension of this cell (Default: 1)
input_logits – if True, treats compartment p as logits and applies a sigmoidal link, i.e., _p = sigmoid(p), to obtain the parameter p of Bern(X=1; p)
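To make the compartments above concrete, here is a minimal pure-Python sketch of what a factorized Bernoulli error cell could compute. It is not ngclearn's API: the function `bernoulli_error`, the clipping constant `eps`, and the choice of L as the masked Bernoulli log-likelihood are assumptions, and the library's exact sign and scaling conventions may differ. It does illustrate why `input_logits` matters: with logits, the derivative collapses to `target - sigmoid(logit)`.

```python
import math
from typing import List, Optional, Tuple


def sigmoid(x: float) -> float:
    """Logistic link: maps a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def bernoulli_error(
    p: List[float],
    target: List[float],
    modulator: Optional[List[float]] = None,
    mask: Optional[List[float]] = None,
    input_logits: bool = False,
    eps: float = 1e-7,
) -> Tuple[float, List[float], List[float]]:
    """Hypothetical sketch of a factorized Bernoulli error cell.

    Returns (L, dp, dtarget), where L is the masked Bernoulli log-likelihood
    (an assumption about the loss this cell embodies).
    """
    n = len(p)
    modulator = [1.0] * n if modulator is None else modulator
    mask = [1.0] * n if mask is None else mask
    L = 0.0
    dp: List[float] = []
    dtarget: List[float] = []
    for p_i, t_i, m_i, k_i in zip(p, target, modulator, mask):
        # Apply the sigmoidal link if p holds logits, then clip away from 0 and 1.
        q = sigmoid(p_i) if input_logits else p_i
        q = min(max(q, eps), 1.0 - eps)
        L += k_i * (t_i * math.log(q) + (1.0 - t_i) * math.log(1.0 - q))
        if input_logits:
            g = t_i - q  # dL/dlogit simplifies to target - sigmoid(logit)
        else:
            g = t_i / q - (1.0 - t_i) / (1.0 - q)  # dL/dp
        # Error signals are gated by the modulator and the mask.
        dp.append(g * m_i * k_i)
        dtarget.append((math.log(q) - math.log(1.0 - q)) * m_i * k_i)
    return L, dp, dtarget
```

For example, a logit of 0 (probability 0.5) against a target of 1 yields dp = [0.5], while passing the probability 0.8 directly yields dp = [1.25].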
ngclearn.components.neurons.graded.gaussianErrorCell module
- class ngclearn.components.neurons.graded.gaussianErrorCell.GaussianErrorCell(*args, **kwargs)
Bases: JaxComponent
A simple (non-spiking) Gaussian error cell; this is a fixed-point solution of a mismatch signal.
Cell Input Compartments:
mu - predicted value (takes in external signals)
Sigma - predicted covariance (takes in external signals)
target - desired/goal value (takes in external signals)
modulator - modulation signal (takes in optional external signals)
mask - binary/gating mask to apply to error neuron calculations
Cell Output Compartments:
L - local loss function embodied by this cell
dmu - derivative of L w.r.t. mu
dSigma - derivative of L w.r.t. Sigma
dtarget - derivative of L w.r.t. target
- Parameters:
name – the string name of this cell
n_units – number of cellular entities (neural population size)
batch_size – batch size dimension of this cell (Default: 1)
sigma – initial/fixed value for the prediction covariance matrix (𝚺) of the multivariate Gaussian distribution; note that if the Sigma compartment is never used, this cell assumes that the covariance collapses to the constant/fixed value sigma
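A comparable sketch for the Gaussian cell, assuming a fixed isotropic covariance sigma * I (i.e., the Sigma compartment is unused) and taking L to be the Gaussian log-likelihood up to an additive constant. The name `gaussian_error` is hypothetical, dSigma is not modeled, and ngclearn's own scaling conventions may differ.

```python
from typing import List, Optional, Tuple


def gaussian_error(
    mu: List[float],
    target: List[float],
    sigma: float = 1.0,
    modulator: Optional[List[float]] = None,
    mask: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float]]:
    """Hypothetical sketch: Gaussian error cell with fixed covariance sigma * I."""
    n = len(mu)
    modulator = [1.0] * n if modulator is None else modulator
    mask = [1.0] * n if mask is None else mask
    # Masked Gaussian log-likelihood, up to an additive constant.
    L = -0.5 * sum(k * (t - m) ** 2 for m, t, k in zip(mu, target, mask)) / sigma
    # dL/dmu is the precision-weighted mismatch, gated by modulator and mask.
    dmu = [(t - m) / sigma * mod * k for m, t, mod, k in zip(mu, target, modulator, mask)]
    # The mismatch enters with the opposite sign for the target.
    dtarget = [-d for d in dmu]
    return L, dmu, dtarget
```

With mu = [0, 1], target = [1, 1], and sigma = 2, only the first unit carries error: dmu = [0.5, 0.0] and L = -0.25.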
ngclearn.components.neurons.graded.laplacianErrorCell module
- class ngclearn.components.neurons.graded.laplacianErrorCell.LaplacianErrorCell(*args, **kwargs)
Bases: JaxComponent
A simple (non-spiking) Laplacian error cell; this is a fixed-point solution of a mismatch/error signal.
Cell Input Compartments:
shift - predicted shift value (takes in external signals)
Scale - predicted scale (takes in external signals)
target - desired/goal value (takes in external signals)
modulator - modulation signal (takes in optional external signals)
mask - binary/gating mask to apply to error neuron calculations
Cell Output Compartments:
L - local loss function embodied by this cell
dshift - derivative of L w.r.t. shift
dScale - derivative of L w.r.t. Scale
dtarget - derivative of L w.r.t. target
- Parameters:
name – the string name of this cell
n_units – number of cellular entities (neural population size)
batch_size – batch size dimension of this cell (Default: 1)
scale – initial/fixed value for the prediction scale matrix of the multivariate Laplacian distribution; note that if the Scale compartment is never used, this cell assumes that the scale collapses to the constant/fixed value scale
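Under similar assumptions (a fixed scalar scale, L taken as the Laplacian log-likelihood up to a constant, and a hypothetical function name `laplacian_error`), the Laplacian cell replaces the squared mismatch with an absolute one. The error signal therefore becomes the sign of the mismatch divided by the scale, which is bounded regardless of how large the mismatch is. dScale is not modeled.

```python
from typing import List, Optional, Tuple


def sign(v: float) -> float:
    """Sign function with sign(0) = 0."""
    return float((v > 0) - (v < 0))


def laplacian_error(
    shift: List[float],
    target: List[float],
    scale: float = 1.0,
    modulator: Optional[List[float]] = None,
    mask: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float]]:
    """Hypothetical sketch: Laplacian error cell with a fixed scalar scale."""
    n = len(shift)
    modulator = [1.0] * n if modulator is None else modulator
    mask = [1.0] * n if mask is None else mask
    # Masked Laplacian log-likelihood, up to an additive constant.
    L = -sum(k * abs(t - s) for s, t, k in zip(shift, target, mask)) / scale
    # dL/dshift is the sign of the mismatch, divided by the scale and gated.
    dshift = [sign(t - s) / scale * mod * k for s, t, mod, k in zip(shift, target, modulator, mask)]
    dtarget = [-d for d in dshift]
    return L, dshift, dtarget
```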
ngclearn.components.neurons.graded.leakyNoiseCell module
- class ngclearn.components.neurons.graded.leakyNoiseCell.LeakyNoiseCell(*args, **kwargs)
Bases: JaxComponent
A non-spiking cell driven by the gradient dynamics entailed by a continuous-time noisy, leaky recurrent state.
Reference: https://pmc.ncbi.nlm.nih.gov/articles/PMC4771709/
The specific differential equation that characterizes this cell (for adjusting x) is:
tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha (sigma_rec)^2) * eps
where j_in is the set of incoming input signals, j_rec is the set of recurrent input signals, and eps is a sample of unit Gaussian noise, i.e., eps ~ N(0, 1).
Cell Input Compartments:
j_input - input (bottom-up) electrical/stimulus current (takes in external signals)
j_recurrent - recurrent electrical/stimulus pressure
Cell State Compartments:
x - noisy rate activity / current value of state
Cell Output Compartments:
r - post-rectified activity, i.e., fx(x) (relu(x) by default)
- Parameters:
name – the string name of this cell
n_units – number of cellular entities (neural population size)
tau_x – state membrane time constant (milliseconds)
act_fx – rectification function (Default: “relu”)
output_scale – factor to multiply output of nonlinearity of this cell by (Default: 1.)
integration_type –
type of integration to use for this cell’s dynamics; currently supported forms include “euler” (Euler/RK-1 integration) and “midpoint” or “rk2” (midpoint method/RK-2 integration) (Default: “euler”)
- Note:
setting the integration type to the midpoint method increases the accuracy of the estimate of the cell’s evolution at the cost of extra computation (and simulation time)
sigma_rec – noise scaling factor / standard deviation (Default: 1)
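The ODE above can be stepped forward with Euler integration, as in the following sketch. It is hypothetical: `leaky_noise_euler_step` is not an ngclearn function. The documentation does not define `alpha`; here it is assumed to be dt / tau_x unless passed explicitly, and the noise term is kept inside the bracket exactly as the equation is written. The output applies relu (the default act_fx) followed by output_scale.

```python
import math
import random
from typing import List, Optional, Tuple


def relu(v: float) -> float:
    """Default rectification for this cell (act_fx = "relu")."""
    return max(0.0, v)


def leaky_noise_euler_step(
    x: List[float],
    j_input: List[float],
    j_recurrent: List[float],
    dt: float,
    tau_x: float,
    sigma_rec: float = 1.0,
    alpha: Optional[float] = None,
    output_scale: float = 1.0,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[float]]:
    """One Euler step of tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha sigma_rec^2) * eps."""
    if alpha is None:
        alpha = dt / tau_x  # assumption: alpha is not defined in the documentation
    if rng is None:
        rng = random.Random()
    noise_scale = math.sqrt(2.0 * alpha * sigma_rec ** 2)
    x_next: List[float] = []
    for x_i, jin_i, jrec_i in zip(x, j_input, j_recurrent):
        eps = rng.gauss(0.0, 1.0)  # unit Gaussian sample, eps ~ N(0, 1)
        dx_dt = (-x_i + jrec_i + jin_i + noise_scale * eps) / tau_x
        x_next.append(x_i + dt * dx_dt)
    # Output compartment r: rectified state, scaled by output_scale.
    r = [output_scale * relu(x_i) for x_i in x_next]
    return x_next, r
```

Setting sigma_rec = 0 removes the noise and leaves a plain leaky integrator, which is convenient for checking the deterministic part of the dynamics; passing a seeded `random.Random` makes noisy runs reproducible.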
ngclearn.components.neurons.graded.rateCell module
- class ngclearn.components.neurons.graded.rateCell.RateCell(*args, **kwargs)
Bases: JaxComponent
A non-spiking cell driven by the gradient dynamics of neural generative coding-driven predictive processing.
The specific differential equation that characterizes this cell (for adjusting z, given current j, over time) is:
tau_m * dz/dt = lambda * prior(z) + (j + j_td)
where j is the set of general incoming input signals (e.g., message-passed signals), j_td is the set of top-down pressure signals, and lambda is the leak-rate scalar given as the second element of prior.
Cell Input Compartments:
j - input pressure (takes in external signals)
j_td - top-down pressure input (takes in external signals)
Cell State Compartments:
z - rate activity
Cell Output Compartments:
zF - post-activation function activity, i.e., fx(z)
- Parameters:
name – the string name of this cell
n_units – number of cellular entities (neural population size)
tau_m – membrane/state time constant (milliseconds)
prior –
a kernel specifying the type of centered scale-shift distribution to impose over neuronal dynamics, applied to each neuron or dimension within this component (Default: (“gaussian”, 0)); this is a tuple whose first element is the string name of the distribution to use and whose second element is a leak-rate scalar that controls the influence/weighting this distribution has on the dynamics; for example, (“laplacian”, 0.001) means that a centered Laplacian distribution scaled by 0.001 is injected into this cell’s dynamics ODE at each step of simulated time
- Note:
supported scale-shift distributions include “laplacian”, “cauchy”, “exp”, and “gaussian”
act_fx – string name of activation function/nonlinearity to use
output_scale – factor to multiply output of nonlinearity of this cell by (Default: 1.)
integration_type –
type of integration to use for this cell’s dynamics; currently supported forms include “euler” (Euler/RK-1 integration) and “midpoint” or “rk2” (midpoint method/RK-2 integration) (Default: “euler”)
- Note:
setting the integration type to the midpoint method increases the accuracy of the estimate of the cell’s evolution at the cost of extra computation (and simulation time)
resist_scale – a scaling factor applied to incoming pressure j (default: 1)
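Both integration options can be compared on the RateCell ODE with a small sketch. Everything here is an assumption for illustration: prior(z) is interpreted as the gradient of the log-density of the centered distribution (so the Gaussian prior contributes -z and acts as a leak toward zero), only "gaussian", "laplacian", and "cauchy" are covered ("exp" is omitted because its exact form is not given here), resist_scale is applied to j only, inputs are held constant within a step, and act_fx/output_scale are not modeled. The helper names are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

# Assumed interpretation: prior(z) = d/dz log p(z) for a centered distribution.
PRIOR_GRADS: Dict[str, Callable[[float], float]] = {
    "gaussian": lambda z: -z,
    "laplacian": lambda z: -float((z > 0) - (z < 0)),
    "cauchy": lambda z: -2.0 * z / (1.0 + z * z),
}


def dz_dt(z: List[float], j: List[float], j_td: List[float], tau_m: float,
          prior: Tuple[str, float] = ("gaussian", 0.0),
          resist_scale: float = 1.0) -> List[float]:
    """Right-hand side of tau_m * dz/dt = lambda * prior(z) + (j + j_td)."""
    name, lam = prior
    grad = PRIOR_GRADS[name]
    return [(lam * grad(z_i) + resist_scale * j_i + jtd_i) / tau_m
            for z_i, j_i, jtd_i in zip(z, j, j_td)]


def euler_step(z, j, j_td, dt, tau_m, prior=("gaussian", 0.0), resist_scale=1.0):
    """RK-1: one derivative evaluation per step."""
    d = dz_dt(z, j, j_td, tau_m, prior, resist_scale)
    return [z_i + dt * d_i for z_i, d_i in zip(z, d)]


def midpoint_step(z, j, j_td, dt, tau_m, prior=("gaussian", 0.0), resist_scale=1.0):
    """RK-2 (midpoint): evaluate the derivative again at a half step."""
    d1 = dz_dt(z, j, j_td, tau_m, prior, resist_scale)
    z_half = [z_i + 0.5 * dt * d_i for z_i, d_i in zip(z, d1)]
    d2 = dz_dt(z_half, j, j_td, tau_m, prior, resist_scale)
    return [z_i + dt * d_i for z_i, d_i in zip(z, d2)]
```

With tau_m = 10, dt = 1, and a Gaussian prior with lambda = 1, starting from z = 1 with no input, one Euler step gives 0.9 and one midpoint step gives 0.905, closer to the exact decay exp(-0.1) ≈ 0.9048. This is the accuracy/cost trade-off described in the integration_type note: two derivative evaluations per step instead of one.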
ngclearn.components.neurons.graded.rewardErrorCell module
- class ngclearn.components.neurons.graded.rewardErrorCell.RewardErrorCell(*args, **kwargs)
Bases: JaxComponent
A reward prediction error (RPE) cell.
Cell Input Compartments:
reward - current reward signal at time t
accum_reward - current accumulated episodic reward signal
Cell Output Compartments:
mu - current moving-average prediction of reward at time t
rpe - current reward prediction error (RPE) signal
accum_reward - current accumulated episodic reward signal (if the online predictor is not used)
- Parameters:
name – the string name of this cell
n_units – number of cellular entities (neural population size)
alpha – decay factor applied to the (exponential) moving-average prediction
ema_window_len – exponential moving-average window length; used only in the evolve step for updating episodic reward signals (Default: 10)
use_online_predictor – whether to use online prediction of the reward signal (Default: True); if set to False, reward prediction only occurs upon a call to this cell’s evolve function
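The following scalar sketch of the online predictor rests on two assumptions: the RPE is computed against the prediction made before the update, and mu is updated as mu ← (1 - alpha) * mu + alpha * reward. The class `OnlineRPESketch` is hypothetical; it does not model ema_window_len or the evolve-time episodic update, and a population of n_units would apply the same rule elementwise.

```python
from dataclasses import dataclass


@dataclass
class OnlineRPESketch:
    """Hypothetical scalar reward-prediction-error unit with an EMA predictor."""

    alpha: float = 0.1         # assumed role: weight on the newest reward in the EMA
    mu: float = 0.0            # current moving-average prediction of reward
    accum_reward: float = 0.0  # accumulated episodic reward

    def step(self, reward: float) -> float:
        # RPE against the prediction held before this step (assumed ordering).
        rpe = reward - self.mu
        # Exponential moving-average update of the reward prediction.
        self.mu = (1.0 - self.alpha) * self.mu + self.alpha * reward
        self.accum_reward += reward
        return rpe
```

Feeding a constant reward of 1.0 with alpha = 0.5 produces RPEs of 1.0, then 0.5, shrinking as mu approaches the reward, which is the expected behavior of a prediction error against a moving-average predictor.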