Source code for ngclearn.components.input_encoders.ganglionCell

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random
from ngclearn import compilable
from ngclearn import Compartment
import jax
from typing import Union, Tuple

def _create_gaussian_filter(patch_shape, sigma):
    ## Create a 2D Gaussian kernel centered on patch_shape with given sigma.
    px, py = patch_shape

    x_ = jnp.linspace(0, px - 1, px)
    y_ = jnp.linspace(0, py - 1, py)

    x, y = jnp.meshgrid(x_, y_)

    xc = px // 2
    yc = py // 2

    filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
    return filter / jnp.sum(filter)

def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
    """
    Build a difference-of-Gaussians (DoG) kernel over a patch.

    The kernel is a narrow "center" Gaussian (std ``sigma``) minus ``lmbda``
    times a wider "surround" Gaussian (std ``sigma * k``). Both Gaussians come
    from ``_create_gaussian_filter`` on the same ``patch_shape`` grid and are
    each normalized to sum to one, so the DoG sums to ``1 - lmbda``. It is not
    re-centered to zero mean.

    Args:
        patch_shape: kernel size (px, py)

        sigma: standard deviation of the center Gaussian

        k: ratio of the surround std to the center std (Default: 1.6)

        lmbda: weight applied to the surround Gaussian (Default: 1)

    Returns:
        the DoG kernel as produced by ``_create_gaussian_filter``'s grid
    """
    center = _create_gaussian_filter(patch_shape, sigma=sigma)
    surround = _create_gaussian_filter(patch_shape, sigma=sigma * k)
    return center - lmbda * surround

def _create_patches(obs, patch_shape, step_shape):
    """
    Extract 2D patches from a batch of images using a sliding window.

    Args:
        obs: Input array (B, ix, iy)

        patch_shape: Patch size (px, py)

        step_shape: Stride (sx, sy) -- use 0 for full-overlap

    Returns:
        Patches array (B, n_cells, px, py)

    """

    B, ix, iy = obs.shape
    px, py = patch_shape
    sx, sy = step_shape

    if sx == 0:
        n_x = ix // px
    else:
        n_x = (ix - px) // sx + 1

    if sy == 0:
        n_y = iy // py
    else:
        n_y = (iy - py) // sy + 1

    patches = jnp.stack([
        obs[:,
            i * sx:i * sx + px, j * sy:j * sy + py
            ] for i in range(n_x)
              for j in range(n_y)
    ], axis=1)

    return patches


class RetinalGanglionCell(JaxComponent):
    """
    A group of retinal ganglion cells that sense input stimuli and send out
    filtered signals (as output). Note that these simulated cells apply an
    internal filter (based on either a Gaussian or a difference-of-Gaussian
    kernel) to each cell's receptive-field patch of the input.

    | --- Cell Input Compartments: ---
    | inputs - input (takes in external signals)
    | --- Cell State Compartments: ---
    | filter - filter (function applied to input)
    | --- Cell Output Compartments: ---
    | outputs - output

    Args:
        name: the string name of this cell

        filter_type: string name of filter function; any value other than the
            supported names below uses an identity (all-ones) filter

            :Note: supported filters include "gaussian", "difference_of_gaussian"

        area_shape: shape (ix, iy) of the receptive field area of ganglion
            cells in this module (all together)

        n_cells: number of ganglion cells in this module; must equal the
            number of patch windows that fit in ``area_shape`` given
            ``patch_shape`` and ``step_shape``

        patch_shape: shape (px, py) of each ganglion cell's receptive field area

        step_shape: stride (sx, sy) between neighboring ganglion cells'
            receptive fields (see ``_create_patches`` for how a 0 stride is
            handled)

        batch_size: batch size dimension of this cell/module (Default: 1)

        sigma: standard deviation of (gaussian) kernel (Default: 1.0)

        key: PRNG key passed to the parent component (Default: None)

    Raises:
        ValueError: if ``n_cells`` does not match the number of patch windows
    """

    def __init__(
            self,
            name: str,
            filter_type: str,
            area_shape: Tuple[int, int],
            n_cells: int,
            patch_shape: Tuple[int, int],
            step_shape: Tuple[int, int],
            batch_size: int = 1,
            sigma: float = 1.0,
            key: Union[jax.Array, None] = None,
            **kwargs
    ):
        super().__init__(name=name, key=key)

        ## Layer Size Setup
        self.filter_type = filter_type
        self.n_cells = n_cells
        self.sigma = sigma
        self.batch_size = batch_size
        self.area_shape = area_shape
        self.patch_shape = patch_shape
        self.step_shape = step_shape

        ## advance_state reshapes the patches to (-1, n_cells * px * py), so
        ## n_cells must equal the number of windows the patch grid produces;
        ## otherwise the emitted output disagrees with the `outputs` shape.
        ## A zero stride along an axis yields extent // size windows there.
        n_windows = 1
        for extent, size, step in zip(area_shape, patch_shape, step_shape):
            stride = size if step == 0 else step
            n_windows *= max((extent - size) // stride + 1, 0)
        if n_cells != n_windows:
            raise ValueError(
                f"{name}: n_cells={n_cells} does not match the {n_windows} "
                f"receptive fields implied by area_shape={tuple(area_shape)}, "
                f"patch_shape={tuple(patch_shape)}, "
                f"step_shape={tuple(step_shape)}"
            )

        if filter_type == 'gaussian':
            kernel = _create_gaussian_filter(patch_shape=self.patch_shape,
                                             sigma=self.sigma)
        elif filter_type == 'difference_of_gaussian':
            kernel = _create_dog_filter(patch_shape=self.patch_shape,
                                        sigma=self.sigma)
        else:
            ## unrecognized filter names fall back to an identity filter
            kernel = jnp.ones(self.patch_shape)

        # ═════════════════ compartments initial values ════════════════════
        in_restVals = jnp.zeros((batch_size, *self.area_shape))  ## input: (B | ix | iy)
        out_restVals = jnp.zeros(  ## output.shape: (B | n_cells * px * py)
            (batch_size,
             self.n_cells * self.patch_shape[0] * self.patch_shape[1]))

        # ═══════════════════ set compartments ══════════════════════
        self.inputs = Compartment(in_restVals, display_name="Input Stimulus")  # input compartment
        self.filter = Compartment(kernel, display_name="Filter")  # Filter compartment
        self.outputs = Compartment(out_restVals, display_name="Output Signal")  # output compartment

    @compilable
    def advance_state(self, t):
        ## Filter every receptive-field patch of the current input and emit the
        ## flattened, per-sample mean-centered responses.
        inputs = self.inputs.get()
        _filter = self.filter.get()
        px, py = self.patch_shape

        # ═══════════════════ extract patches for filters ══════════════════
        input_patches = _create_patches(inputs, patch_shape=self.patch_shape,
                                        step_shape=self.step_shape)

        # ═══════════════════ apply filter to all patches ══════════════════
        filtered_input = input_patches * _filter  ## shape: (B | n_cells | px | py)

        # ════════════ reshape all cells responses to a single input to brain ════════════
        filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py))  ## shape: (B | n_cells * px * py)

        # ═══════════════════ normalize filtered signals ══════════════════
        ## subtract each sample's mean across all cells and pixels
        outputs = filtered_input - jnp.mean(filtered_input, axis=1, keepdims=True)  ## shape: (B | n_cells * px * py)

        self.outputs.set(outputs)

    @compilable
    def reset(self):
        ## Reset inputs/outputs to zeros sized by the current self.batch_size.
        ## To change the batch size, assign self.batch_size and call reset();
        ## the filter compartment is left untouched.
        in_restVals = jnp.zeros((self.batch_size, *self.area_shape))  ## input: (B | ix | iy)
        out_restVals = jnp.zeros(  ## output.shape: (B | n_cells * px * py)
            (self.batch_size,
             self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
        self.inputs.set(in_restVals)
        self.outputs.set(out_restVals)

    @classmethod
    def help(cls):  ## component help function
        properties = {
            "cell_type": "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values"},
            "states":
                {"filter": "Preprocessing function applies to input)"},
            "outputs":
                {"outputs": "Preprocessed signal values emitted at time t"},
        }
        hyperparams = {
            "filter_type": "Type of the filter for preprocessing the input",
            "sigma": "Standard deviation of gaussian kernel/filter",
            "area_shape": "Effective receptive field area shape of ganglion cells in this module",
            "n_cells": "Number of retinal ganglion (center-surround) cells to model in this layer",
            "patch_shape": "Classical receptive field area shape of individual ganglion cells in this module",
            "step_shape": "Extra-classical receptive field area shape each ganglion cell in this module",
            "batch_size": "Batch size dimension of this component"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "~ Gaussian(x)",
                "hyperparameters": hyperparams}
        return info
if __name__ == '__main__':
    from ngcsimlib.context import Context

    ## Demo: three ganglion cells, each with a 16x16 receptive field, sliding
    ## over a 16x26 area in steps of 5 along the second axis.
    with Context("Bar") as ctx:
        rgc = RetinalGanglionCell(
            name="RGC",
            filter_type="gaussian",
            sigma=2.3,
            area_shape=(16, 26),
            n_cells=3,
            patch_shape=(16, 16),
            step_shape=(0, 5),
        )
    print(rgc)