Source code for ngclearn.utils.feature_dictionaries.polynomialLibrary

from jax import jit, random, numpy as jnp
from typing import List, Tuple, Union
from dataclasses import dataclass

@dataclass
class PolynomialLibrary:
    """
    A class for creating polynomial feature libraries in 1D, 2D, or 3D.

    For ``d`` input variables, the library contains every monomial
    ``x^i * y^j * z^k`` whose total degree ``i + j + k`` is at most
    ``poly_order``. Monomials are ordered lexicographically by exponent
    tuple: ``i`` outermost, then ``j``, then ``k``.

    Args:
        poly_order (int): Maximum total order of polynomial terms (Attribute).
            Must be a non-negative integer.

        include_bias (bool): Whether to include the constant (bias) term in
            the output (Attribute)
    """
    poly_order: int = None
    include_bias: bool = True

    def __post_init__(self):
        if self.poly_order is None:
            raise ValueError("poly_order must be specified")
        # The type check must come first: comparing e.g. a string with 0
        # would raise TypeError instead of the intended ValueError.
        if not isinstance(self.poly_order, int) or self.poly_order < 0:
            raise ValueError("poly_order must be a non-negative integer")

    def _create_library(self, *arrays: jnp.ndarray) -> Tuple[jnp.ndarray, List[str]]:
        """
        Create polynomial library for given input arrays.

        Args:
            arrays: Input arrays (x, y, z), each reshaped by the caller to a
                column of shape (n_samples, 1)

        Returns:
            Tuple of (feature matrix, feature names). Column 0 of the matrix
            is a constant seed column with no entry in the names list. The
            remaining columns line up one-to-one with the names:
            ``names[0] == '1'`` is the all-zero-exponent (bias) column.
        """
        dim = len(arrays)
        # Collect columns and concatenate once at the end, instead of
        # re-copying the growing matrix on every term (quadratic cost).
        columns = [jnp.ones_like(arrays[0])]
        names = ['1']
        if dim == 1:
            x = arrays[0]
            for i in range(self.poly_order + 1):
                columns.append(x ** i)
                if i != 0:
                    names.append(r'𝑥^{} |'.format(i))
        elif dim == 2:
            x, y = arrays
            for i in range(self.poly_order + 1):
                for j in range(self.poly_order - i + 1):
                    columns.append(x ** i * y ** j)
                    if not (i == 0 and j == 0):
                        names.append(r'$𝑥^{} . 𝑦^{}$ |'.format(i, j))
        elif dim == 3:
            x, y, z = arrays
            for i in range(self.poly_order + 1):
                for j in range(self.poly_order + 1 - i):
                    for k in range(self.poly_order + 1 - (i + j)):
                        columns.append(x ** i * y ** j * z ** k)
                        if not (i == 0 and j == 0 and k == 0):
                            names.append(r'$𝑥^{} . 𝑦^{} . 𝓏^{}$ |'.format(i, j, k))
        return jnp.concatenate(columns, axis=1), names

    def fit(self, X: List[jnp.ndarray]) -> Tuple[jnp.ndarray, List[str]]:
        """
        Fits this library to a design matrix X

        Args:
            X: sequence of 1 to 3 arrays (one per input variable). Each array
                is flattened to a column; all must hold the same number of
                samples.

        Returns:
            the data-fit/retro-fit library as a tuple of
            (feature matrix of shape (n_samples, n_terms), feature names)

        Raises:
            ValueError: if X does not hold 1 to 3 arrays, or if the arrays
                differ in their number of samples
        """
        if not 1 <= len(X) <= 3:
            raise ValueError("Input must be 1D, 2D, or 3D; e.g. len(X) >= 1 ")
        arrays = [jnp.array(x).reshape(-1, 1) for x in X]
        # Mismatched lengths would otherwise either fail deep inside JAX or,
        # for a length-1 array, silently broadcast into a wrong library.
        if len({a.shape[0] for a in arrays}) != 1:
            raise ValueError("All input arrays must have the same number of samples")
        lib, names = self._create_library(*arrays)
        start_idx = 0 if self.include_bias else 1
        # lib carries one extra leading constant column that has no name,
        # hence the +1 offset on the matrix slice only.
        return lib[:, start_idx + 1:], names[start_idx:]