from ngcsimlib.component import Component
from jax import numpy as jnp, random, jit
from functools import partial
import time, sys
#@partial(jit, static_argnums=[3])
[docs]
def run_cell(dt, targ, mu, eType="gaussian"):
    """
    Advances the error-cell dynamics by a single step.

    This is a thin dispatcher: it always forwards to the Laplacian routine,
    ``run_laplacian_cell``.

    Args:
        dt: integration time constant

        targ: target pattern value

        mu: prediction value

        eType: accepted for interface compatibility but never read here; the
            Laplacian routine is used regardless of its value

    Returns:
        derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, local loss
    """
    d_mean, d_target, local_loss = run_laplacian_cell(dt, targ, mu)
    return d_mean, d_target, local_loss
@jit
def run_laplacian_cell(dt, targ, mu):
    """
    Moves Laplacian cell dynamics one step forward. Specifically, this
    routine emulates the error unit behavior of the local cost functional:

    | L(targ, mu) = -||targ - mu||_1
    | or log likelihood of the Laplace distribution with identity scale

    Args:
        dt: integration time constant (not used by this fixed-point
            computation; kept so the signature matches the caller)

        targ: target pattern value

        mu: prediction value

    Returns:
        derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, loss
    """
    mismatch = targ - mu  # raw (signed) difference between target and prediction
    dmu = jnp.sign(mismatch)  # e (error unit): gradient of L w.r.t. mu
    dtarg = -dmu  # reverse of e: gradient of L w.r.t. targ
    # The loss must be taken over the raw mismatch, not over its sign;
    # |sign(x)| is 0 or 1, so summing it would merely count non-matching
    # entries instead of yielding the (negative) L1 norm documented above.
    L = -jnp.sum(jnp.abs(mismatch))
    return dmu, dtarg, L
[docs]
class LaplacianErrorCell(Component): ## Rate-coded/real-valued error unit/cell
    """
    A simple (non-spiking) Laplacian error cell - this is a fixed-point solution
    of a mismatch signal.

    Values are held in the component's ``compartments`` mapping under the keys
    returned by the ``*Name()`` class methods; the properties below are
    convenience accessors for those entries.

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau_m: (Unused -- currently cell is a fixed-point model)

        leakRate: (Unused -- currently cell is a fixed-point model)

        key: PRNG Key to control determinism of any underlying synapses
            associated with this cell (if None, a key is seeded from the
            current time in nanoseconds)

        useVerboseDict: triggers slower, verbose dictionary mode (Default: False)
    """

    ## Class Methods for Compartment Names
    @classmethod
    def lossName(cls):
        """Returns the compartment key under which the local loss is stored."""
        return "L"

    @classmethod
    def outputCompartmentName(cls):
        """Returns the compartment key of the rate-coded output."""
        return 'e' ## rate-coded output

    @classmethod
    def meanName(cls):
        """Returns the compartment key of the mean (prediction) input."""
        return 'mu'

    @classmethod
    def derivMeanName(cls):
        """Returns the compartment key of the derivative w.r.t. the mean."""
        return 'dmu'

    @classmethod
    def targetName(cls):
        """Returns the compartment key of the target input."""
        return 'target'

    @classmethod
    def derivTargetName(cls):
        """Returns the compartment key of the derivative w.r.t. the target."""
        return 'dtarget'

    @classmethod
    def modulatorName(cls):
        """Returns the compartment key of the (optional) modulator signal."""
        return 'modulator'

    ## Bind Properties to Compartments for ease of use
    @property
    def loss(self):
        """Local loss value, or None if the compartment has not been set."""
        return self.compartments.get(self.lossName(), None)

    @loss.setter
    def loss(self, inp):
        self.compartments[self.lossName()] = inp

    @property
    def mean(self):
        """Mean (prediction) value, or None if the compartment is not set."""
        return self.compartments.get(self.meanName(), None)

    @mean.setter
    def mean(self, inp):
        self.compartments[self.meanName()] = inp

    @property
    def derivMean(self):
        """Derivative w.r.t. the mean, or None if the compartment is not set."""
        return self.compartments.get(self.derivMeanName(), None)

    @derivMean.setter
    def derivMean(self, inp):
        self.compartments[self.derivMeanName()] = inp

    @property
    def target(self):
        """Target value, or None if the compartment is not set."""
        return self.compartments.get(self.targetName(), None)

    @target.setter
    def target(self, inp):
        self.compartments[self.targetName()] = inp

    @property
    def derivTarget(self):
        """Derivative w.r.t. the target, or None if the compartment is not set."""
        return self.compartments.get(self.derivTargetName(), None)

    @derivTarget.setter
    def derivTarget(self, inp):
        self.compartments[self.derivTargetName()] = inp

    @property
    def modulator(self):
        """Pending modulator signal, or None when there is none to apply."""
        return self.compartments.get(self.modulatorName(), None)

    @modulator.setter
    def modulator(self, inp):
        self.compartments[self.modulatorName()] = inp

    # Define Functions
    def __init__(self, name, n_units, tau_m=0., leakRate=0., key=None,
                 useVerboseDict=False, **kwargs):
        super().__init__(name, useVerboseDict, **kwargs)

        ## Random Number Set up
        ## fall back to a time-based seed when the caller supplies no key
        self.key = random.PRNGKey(time.time_ns()) if key is None else key

        ## Layer Size Setup
        self.n_units = n_units
        self.batch_size = 1
        self.reset()

    def verify_connections(self):
        """
        Asks the component metadata to check that both the mean and the target
        compartments have at least one incoming connection.
        """
        self.metadata.check_incoming_connections(self.meanName(), min_connections=1)
        self.metadata.check_incoming_connections(self.targetName(), min_connections=1)

    def advance_state(self, t, dt, **kwargs):
        """
        Computes the error-cell outputs for the current target and mean.

        Stores the two derivatives and the local loss returned by ``run_cell``;
        if a modulator is present, both derivatives are scaled by it and the
        modulator is then cleared.

        Args:
            t: current time (not used by this method)

            dt: integration time constant, forwarded to ``run_cell``
        """
        ## compute Laplacian/MAE error cell output
        self.derivMean, self.derivTarget, self.loss = \
            run_cell(dt, self.target, self.mean)
        if self.modulator is not None:
            self.derivMean = self.derivMean * self.modulator
            self.derivTarget = self.derivTarget * self.modulator
            self.modulator = None ## use and consume modulator

    def reset(self, **kwargs):
        """
        Restores every compartment to its initial state: zero arrays of shape
        (batch_size, n_units) for the derivatives, target and mean, a zero
        loss, and no modulator.
        """
        state_shape = (self.batch_size, self.n_units)
        self.derivMean = jnp.zeros(state_shape)
        self.derivTarget = jnp.zeros(state_shape)
        self.target = jnp.zeros(state_shape) #None
        self.mean = jnp.zeros(state_shape) #None
        ## clear the loss too, so that it is defined before the first step
        ## and no value from before the reset lingers afterwards
        self.loss = 0.
        self.modulator = None

    def save(self, **kwargs):
        """No-op: this method does nothing (its body is only ``pass``)."""
        pass