# Source code for ngclearn.components.synapses.competitive.SOMSynapse

import jax
from jax import random, numpy as jnp, jit
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.utils.model_utils import softmax

from ngclearn.components.synapses.denseSynapse import DenseSynapse

def _gaussian_kernel(dist, sigma): ## Gaussian neighborhood function
    density = jnp.exp(-jnp.power(dist, 2) / (2 * (sigma ** 2)))  # n_units x 1
    return density

def _ricker_marr_kernel(dist, sigma): ## mexican hat neighborhood function
    """Evaluate a Mexican-hat (Ricker-Marr) kernel, clipped to be non-negative.

    The kernel is expressed as a Gaussian envelope multiplied by the
    polynomial term (1 - dist^2 / sigma^2).

    Args:
        dist: distance value(s) at which the kernel is evaluated

        sigma: width parameter of the kernel; broadcast against `dist`

    Returns:
        kernel value(s), computed element-wise, with negative lobes set to 0
    """
    envelope = _gaussian_kernel(dist, sigma)
    modulation = 1. - (jnp.power(dist, 2) / (sigma ** 2))
    # the raw Mexican-hat goes negative for dist > sigma; negative neighborhood
    # weights would disturb SOM learning, so they are clipped away here
    return jnp.maximum(envelope * modulation, 0.)

def _euclidean_dist(a: jax.Array, b: jax.Array):
    """ Compute batch-wise Euclidean (L2) distance between two sets of vectors a and b

    Args:
        a (jax.Array): (batch_size, n_inputs)
        b (jax.Array): (n_inputs, n_units_xy) n_units here is height * width of SOM topology

    Returns:
        d (jax.Array): (batch_size, n_units_xy) distance of each input pattern to each SOM unit
        delta (jax.Array): (batch_size, n_inputs, n_units_xy) raw differences between each input pattern and each SOM unit
    """
    # (B, I, 1) - (1, I, U) -> (B, I, U)
    delta = jnp.expand_dims(a, axis=-1) - jnp.expand_dims(b, axis=0)
    # (B, U) # norm across n_inputs dimension
    d = jnp.linalg.norm(delta, axis=1)
    return d, delta

def _manhattan_dist(a: jax.Array, b:jax.Array):
    """Manhattan (L1) distance

    Args:
        a (jax.Array): (batch_size, n_inputs)
        b (jax.Array): (n_inputs, n_units_xy) n_units here is height * width of SOM topology

    Returns:
        d (jax.Array): (batch_size, n_units_xy) distance of each input pattern to each SOM unit
        delta (jax.Array): (batch_size, n_inputs, n_units_xy) raw differences between each input pattern and each SOM unit
    """
    # (B, I, 1) - (1, I, U) -> (B, I, U)
    delta = jnp.expand_dims(a, axis=-1) - jnp.expand_dims(b, axis=0)
    # (B, U) # norm across n_inputs dimension
    d = jnp.linalg.norm(delta, ord=1, axis=1)
    return d, delta

def _cosine_dist(a: jax.Array, b: jax.Array):
    """Cosine-similarity distance

    Args:
        a (jax.Array): (batch_size, n_inputs)
        b (jax.Array): (n_inputs, n_units_xy) n_units here is height * width of SOM topology

    Returns:
        d (jax.Array): (batch_size, n_units_xy) distance of each input pattern to each SOM unit
        delta (jax.Array): (batch_size, n_inputs, n_units_xy) raw differences between each input pattern and each SOM unit
    """
    # (B, I, 1) - (1, I, U) -> (B, I, U)
    delta = jnp.expand_dims(a, axis=-1) - jnp.expand_dims(b, axis=0)

    # Viet: Original code
    # d = 1. - (jnp.matmul(a.T, b) / (jnp.linalg.norm(a, axis=0) * jnp.linalg.norm(b, axis=0)))

    # Viet: new code for cosine similarity distance (similar code but more readable)
    a_norm = jnp.linalg.norm(a, axis=1, keepdims=True) # (B, 1)
    b_norm = jnp.linalg.norm(b, axis=0, keepdims=True) # (1, U)
    # (I, U) / (B, 1) * (1, U) = (B, U) = (B, I)
    cosine_similarity = a @ b / (a_norm * b_norm) # (B, U)
    d = 1. - cosine_similarity # convert similarity to distance
    return d, delta


class SOMSynapse(DenseSynapse): # Self-organizing map (SOM) synaptic cable
    """
    A synaptic cable that emulates a self-organizing map (or Kohonen map) that is
    adapted via competitive Hebbian learning. Many of this synapse's internal
    compartments house dynamically-updated values for learning elements such as
    the SOM's neighborhood radius and learning rate.

    Mathematically, a synaptic update performed according to SOM theory is:

    | Delta W_{ij} = (x.T - W) * n(BMU) * eta
    | where n(BMU) is a neighborhood weighting function centered around (topological) coordinates of BMU
    | where x is vector of pre-synaptic inputs, W is SOM's synaptic matrix, and BMU is best-matching unit for x

    When a batch of inputs is presented, one BMU / neighborhood weighting is
    computed per sample and the per-sample updates are averaged over the batch.

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | bmu - current best-matching unit (BMU), based on current inputs
    | delta - current differences between inputs and each weight vector of this SOM's synaptic matrix
    | i_tick - current internal tick / marker (gets incremented by 1 for each call to `evolve`)
    | eta - current learning rate value
    | radius - current radius value to control neighborhood function
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | inputs - pre-synaptic signal/value to drive 1st term of SOM update (x)
    | outputs - post-synaptic signal/value to drive 2nd term of SOM update (y)
    | neighbor_weights - topology weighting applied to synaptic adjustments
    | dWeights - current delta matrix containing changes to be applied to synapses

    | References:
    | Kohonen, Teuvo. "The self-organizing map." Proceedings of the IEEE 78.9 (1990): 1464-1480.

    Args:
        name: the string name of this cell

        n_inputs: number of input units to this SOM

        n_units_x: number of output units along the x-axis (1st topology
            coordinate) of the rectangular topology of this SOM

        n_units_y: number of output units along the y-axis (2nd topology
            coordinate) of the rectangular topology of this SOM

        eta: (initial) learning rate / step-size for this SOM (initial condition value for `eta`)

        distance_function: string specifying distance function to use for finding
            best-matching units (BMUs) (Default: "euclidean").
            usage guide:
            "euclidean" = use L2 / Euclidean distance
            "manhattan" = use L1 / Manhattan / taxi-cab distance
            "cosine" = use cosine-similarity distance

        neighbor_function: string specifying neighborhood function to compute
            approximate topology weighting across units in topology (based on BMU)
            (Default: "gaussian").
            usage guide:
            "gaussian" = use Gaussian kernel
            "ricker" = use Mexican-hat / Ricker-Marr kernel

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with 1st element as a string calling the name of
            initialization to use

        resist_scale: a fixed scaling factor to apply to synaptic transform
            (Default: 1.), i.e., yields: out = ((W * Rscale) * in)

        p_conn: probability of a connection existing (default: 1.); setting
            this to < 1. will result in a sparser synaptic structure

        batch_size: batch size dimension of this component (Default: 1); sizes
            the leading axis of the `bmu`, `delta`, and `neighbor_weights` compartments
    """

    def __init__(
            self,
            name,
            n_inputs,
            n_units_x, ## num units along x-axis of SOM rectangular topology
            n_units_y, ## num units along y-axis of SOM rectangular topology
            eta=0.5, ## learning rate
            distance_function="euclidean",
            neighbor_function="gaussian",
            weight_init=None,
            resist_scale=1.,
            p_conn=1.,
            batch_size=1,
            **kwargs
    ):
        shape = (n_inputs, n_units_x * n_units_y)
        super().__init__(
            name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs
        )

        ### build (rectangular) topology coordinates
        # Shape: (n_units_x * n_units_y, 2); row k holds the (x, y) grid position of
        # unit k, enumerated x-major: (0,0), (0,1), ..., (0,n_units_y-1), (1,0), ...
        # Each x value must be repeated n_units_y times (once per y position) so that
        # non-square topologies (n_units_x != n_units_y) are laid out correctly.
        x_idx = jnp.repeat(jnp.arange(n_units_x, dtype=float), n_units_y)
        y_idx = jnp.tile(jnp.arange(n_units_y, dtype=float), n_units_x)
        self.coords = jnp.stack((x_idx, y_idx), axis=1)

        ### Synapse and SOM hyper-parameters
        self.distance_function = distance_function
        self.dist_fx = 0 ## default := 0 (euclidean)
        if "manhattan" in distance_function:
            self.dist_fx = 1
        elif "cosine" in distance_function:
            self.dist_fx = 2
        self.neighbor_function = neighbor_function
        self.neighbor_fx = 0 ## default := 0 (Gaussian)
        if "ricker" in neighbor_function:
            self.neighbor_fx = 1 ## Mexican-hat function
        self.shape = shape ## shape of synaptic efficacy matrix

        ## exponential decay -> dz/dt = -kz has sol'n: z0 exp(-k t)
        self.initial_eta = eta ## alpha (in SOM-lingo)
        self.initial_radius = jnp.maximum(n_units_x, n_units_y) / 2
        self.tau_eta = 50000
        # NOTE(review): log(initial_radius) is 0 when the largest topology side is 2
        #               and negative when it is 1, making tau_radius infinite /
        #               negative for such tiny maps -- confirm this is acceptable
        self.tau_radius = self.tau_eta / jnp.log(self.initial_radius) ## C

        ## SOM Compartment setup
        self.radius = Compartment(jnp.zeros((1, 1)) + self.initial_radius)
        self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta)
        self.i_tick = Compartment(jnp.zeros((1, 1)))
        # batch-aware compartments (one entry per sample in the batch)
        self.bmu = Compartment(jnp.zeros((self.batch_size, 1), dtype=jnp.int32))
        self.delta = Compartment(jnp.zeros((self.batch_size, shape[0], shape[1])))
        self.neighbor_weights = Compartment(jnp.zeros((self.batch_size, shape[1])))
        self.dWeights = Compartment(self.weights.get() * 0)

    def _calc_bmu(self):
        """obtain index of best-matching unit (BMU)

        This will compute the distance between x (batch_size, n_inputs) and
        W (n_inputs, n_units_xy) using the configured distance function.

        Returns:
            jax.Array: bmu (batch_size, 1) index of the closest unit per sample, and
                delta (batch_size, n_inputs, n_units_xy) raw input-minus-weight differences
        """
        x = self.inputs.get()
        W = self.weights.get()
        if self.dist_fx == 1: ## L1 distance
            d, delta = _manhattan_dist(x, W)
        elif self.dist_fx == 2: ## cosine distance
            d, delta = _cosine_dist(x, W)
        else: ## L2 distance
            d, delta = _euclidean_dist(x, W)
        # BMU has to have shape (batch_size, 1)
        return jnp.argmin(d, axis=1, keepdims=True), delta

    def _calc_neighborhood_weights(self): ## neighborhood function
        # bmu is now a vector of indices, each one is best-matching unit for each sample in batch
        bmu = self.bmu.get().reshape(-1) ## get best-matching unit per sample, flatten to (batch_size,)
        coords = self.coords ## constant coordinate array
        radius = self.radius.get() ## get current neighborhood radius value

        coord_bmu = coords[bmu, :] # (B, 2) get coordinates of BMU for each sample in batch
        # (1, n_units_xy, 2) - (B, 1, 2) -> (B, n_units_xy, 2)
        # get delta between coordinates of each SOM unit and BMU
        delta = jnp.expand_dims(coords, axis=0) - jnp.expand_dims(coord_bmu, axis=1)

        ### neighborhood-weighting computation note:
        ### internally, calculation of neighborhood weighting depends on 1st calculating
        ### L2 distance in Cartesian coordinate-space, then applying the neighborhood
        ### over these coordinate distance values
        # (B, n_units_xy)
        bmu_dist = jnp.linalg.norm(delta, axis=2)
        if self.neighbor_fx == 1: ## apply Mexican-hat kernel
            neighbor_weights = _ricker_marr_kernel(bmu_dist, sigma=radius)
        else: ## apply Gaussian kernel
            neighbor_weights = _gaussian_kernel(bmu_dist, sigma=radius)
        ## TODO: add in triangular, bubble, & laplacian kernels
        # (B, n_units_xy)
        return neighbor_weights

    @compilable
    def advance_state(self): ## forward-inference step of SOM
        bmu_idx, delta = self._calc_bmu()
        self.bmu.set(bmu_idx) ## store BMU
        self.delta.set(delta) ## store delta/differences
        neighbor_weights = self._calc_neighborhood_weights()
        self.neighbor_weights.set(neighbor_weights) ## store neighborhood weightings

        ### obtain weighted competitive activations (via softmax probs)
        activity = softmax(neighbor_weights * self.resist_scale)
        self.outputs.set(activity)

    @compilable
    def evolve(self, t, dt): ## competitive Hebbian update step of SOM
        # (B, n_inputs, n_units_xy)
        delta = self.delta.get() ## deltas/differences between input & all SOM templates
        # (B, n_units_xy)
        neighbor_weights = self.neighbor_weights.get() ## get neighborhood weight values

        ## exponential decay -> dz/dt = -kz, integrated with one Euler step per call
        ## update neighborhood radius
        r = self.radius.get()
        r = r + (-r) * (1. / self.tau_radius)
        self.radius.set(r)
        ## update learning rate alpha
        a = self.eta.get()
        a = a + (-a) * (1. / self.tau_eta)
        self.eta.set(a)

        ## calculate change-in-synapses (uses the just-decayed learning rate)
        # (B, n_inputs, n_units_xy) * (B, 1, n_units_xy) -> (B, n_inputs, n_units_xy)
        dWeights = delta * jnp.expand_dims(neighbor_weights, axis=1) * self.eta.get()
        # batch mode: average the per-sample updates across the batch dimension
        dWeights = dWeights.mean(axis=0) ## (n_inputs, n_units_xy)
        self.dWeights.set(dWeights)
        _W = self.weights.get() + dWeights ## update via competitive Hebbian rule
        self.weights.set(_W)

        # update tick. NOTE: are we ticking by batch size or 1?
        self.i_tick.set(self.i_tick.get() + 1)

    @compilable
    def reset(self): ## zero out per-step signal compartments (weights/eta/radius/i_tick are left as-is)
        preVals = jnp.zeros((self.batch_size, self.shape[0]))
        postVals = jnp.zeros((self.batch_size, self.shape[1]))

        if not self.inputs.targeted:
            self.inputs.set(preVals)
        self.outputs.set(postVals)
        self.dWeights.set(jnp.zeros(self.shape))
        self.delta.set(jnp.zeros((self.batch_size, self.shape[0], self.shape[1])))
        self.bmu.set(jnp.zeros((self.batch_size, 1), dtype=jnp.int32))
        self.neighbor_weights.set(jnp.zeros((self.batch_size, self.shape[1])))

    @classmethod
    def help(cls): ## component help function
        properties = {
            "synapse_type": "SOMSynapse - performs an adaptable synaptic transformation of inputs to produce output "
                            "signals; synapses are adjusted via competitive Hebbian learning in accordance with a "
                            "Kohonen map"
        }
        compartment_props = {
            "input_compartments":
                {"inputs": "Takes in external input signal values",
                 "key": "JAX PRNG key"},
            "parameter_compartments":
                {"weights": "Synapse efficacy/strength parameter values"},
            "output_compartments":
                {"outputs": "Output of synaptic transformation",
                 "bmu": "Best-matching unit (BMU)"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "eta": "Global learning rate",
            "radius": "Radius parameter to control influence of neighborhood function",
            "distance_function": "Distance function used to compute BMU"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [W * alpha(bmu)] ;"
                            "dW = SOM competitive Hebbian update",
                "hyperparameters": hyperparams}
        return info
# if __name__ == '__main__': # from ngcsimlib.context import Context # with Context("Bar") as bar: # Wab = SOMSynapse("Wab", (2, 3), 4, 4, 1.) # print(Wab)