from jax import random, numpy as jnp, jit
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngcsimlib.logger import info
from ngclearn.utils.distribution_generator import DistributionGenerator
from ngclearn.components.synapses.convolution.ngcconv import conv2d
from ngclearn.components.jaxComponent import JaxComponent
class ConvSynapse(JaxComponent): ## base-level convolutional cable
    """
    A base convolutional synaptic cable.

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals
    | weights - current value tensor of filter/kernel efficacies
    | biases - current base-rate/bias efficacies

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
            with `filter height x filter width x input channels x number output channels`);
            note that currently filters/kernels are assumed to be square
            (kernel.width = kernel.height)

        x_shape: 2d shape of input map signal (component currently assumes square
            input maps; only the second entry of `x_shape` is used)

        filter_init: a kernel to drive initialization of this synaptic cable's
            filter values (Default: None, which falls back to a uniform(0.025, 0.8)
            initializer)

        bias_init: kernel to drive initialization of bias/base-rate values
            (Default: None, which turns off/disables biases, i.e., b = 0)

        stride: length/size of stride

        padding: pre-operator padding to use -- "VALID" (none), "SAME"

        resist_scale: a fixed (resistance) scaling factor to apply to synaptic
            transform (Default: 1.), i.e., yields: out = ((K @ in) * resist_scale) + b
            where `@` denotes convolution

        batch_size: batch size dimension of this component
    """

    def __init__(
            self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None,
            resist_scale=1., batch_size=1, **kwargs
    ):
        super().__init__(name, **kwargs)
        self.filter_init = filter_init
        self.bias_init = bias_init

        ## Synapse meta-parameters
        self.shape = shape ## shape of synaptic filter tensor
        ## input maps are assumed square; the second entry of x_shape is taken as the side length
        _, x_size = x_shape
        self.x_size = x_size
        self.resist_scale = resist_scale ## post-transformation scale factor
        self.padding = padding
        self.stride = stride

        ####################### Set up padding arguments #######################
        ## kernels are assumed square; the second entry is taken as the side length
        _, k_size, n_in_chan, n_out_chan = shape
        ## NOTE(review): pad_args is not consumed inside this class; presumably it is
        ## used by subclasses (e.g., for learning-rule convolutions) -- verify
        self.pad_args = None
        if padding == "SAME":
            ## total padding needed per spatial axis so that out_size = ceil(x_size / stride)
            if x_size % stride == 0:
                pad_along = max(k_size - stride, 0)
            else:
                pad_along = max(k_size - (x_size % stride), 0)
            pad_lo = pad_along // 2 ## smaller half goes first
            pad_hi = pad_along - pad_lo
            self.pad_args = ((pad_lo, pad_hi), (pad_lo, pad_hi))
        elif padding == "VALID":
            self.pad_args = ((0, 0), (0, 0))

        ######################### set up compartments ##########################
        _, *subkeys = random.split(self.key.get(), 4)
        if self.filter_init is None:
            info(self.name, "is using default weight initializer!")
            self.filter_init = DistributionGenerator.uniform(0.025, 0.8)
        weights = self.filter_init(shape, subkeys[0]) ## filter tensor
        self.batch_size = batch_size

        ## Compartment setup and shape computation: probe the convolution with a
        ## zero input (channels-last layout) to discover the output shape
        _x = jnp.zeros((self.batch_size, x_size, x_size, n_in_chan))
        _d = conv2d(_x, weights, stride_size=stride, padding=padding)
        self.in_shape = _x.shape
        self.out_shape = _d.shape
        self.inputs = Compartment(jnp.zeros(self.in_shape))
        self.outputs = Compartment(jnp.zeros(self.out_shape))
        self.weights = Compartment(weights)

        if self.bias_init is None:
            info(self.name, "is using default bias value of zero (no bias kernel provided)!")
            biases = 0.0
        else:
            ## one bias per output channel, so it broadcasts along the last
            ## (channel) axis of the convolution output
            biases = self.bias_init((1, n_out_chan), subkeys[2])
        self.biases = Compartment(biases)

    @compilable
    def advance_state(self):
        """Compute outputs = (conv2d(inputs, weights) * resist_scale) + biases."""
        _x = self.inputs.get()
        ## FIXME: does resist_scale affect update rules?
        outputs = conv2d(
            _x, self.weights.get(), stride_size=self.stride, padding=self.padding
        ) * self.resist_scale + self.biases.get()
        self.outputs.set(outputs)

    @compilable
    def reset(self):
        """Zero out the input and output compartments (using their setup-time shapes)."""
        preVals = jnp.zeros(self.in_shape)
        postVals = jnp.zeros(self.out_shape)
        self.inputs.set(preVals)
        self.outputs.set(postVals)

    @classmethod
    def help(cls): ## component help function
        """Return a dictionary describing this component's compartments and hyperparameters."""
        properties = {
            "synapse type": "ConvSynapse - performs a synaptic convolution (@) of inputs "
                            "to produce output signals"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values"},
            "states":
                {"weights": "Synaptic filter parameter values",
                 "biases": "Base-rate/bias parameter values",
                 "key": "JAX PRNG key"},
            "outputs":
                {"outputs": "Output of synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of synaptic filter value matrix; `kernel width` x `kernel height` "
                     "x `number input channels` x `number output channels`",
            "x_shape": "Shape of any single incoming/input feature map",
            "filter_init": "Initialization conditions for synaptic filter (K) values",
            "bias_init": "Initialization conditions for bias/base-rate (b) values",
            "resist_scale": "Resistance level output scaling factor (R)",
            "stride": "length / size of stride",
            "padding": "pre-operator padding to use, i.e., `VALID` `SAME`",
            "batch_size": "Batch size dimension of this component"
        }
        ## named help_info so the module-level `info` logger is not shadowed
        help_info = {cls.__name__: properties,
                     "compartments": compartment_props,
                     "dynamics": "outputs = [K @ inputs] * R + b",
                     "hyperparameters": hyperparams}
        return help_info