# Source code for ngclearn.utils.JaxProcessesMixin

from ngcsimlib import JointProcess, MethodProcess
from ngcsimlib.global_state import stateManager
import jax
from typing import TYPE_CHECKING

from ngcsimlib._src.parser.utils import CompiledMethod

if TYPE_CHECKING:
    from ngcsimlib._src.process.baseProcess import BaseProcess

class JaxCompiledMethod(CompiledMethod):
    """
    A compiled method whose callable has been wrapped with ``jax.jit``.

    Only the JAX process mixin is meant to build these; other code should not
    construct or depend on this class directly. The un-jitted function stays
    reachable through :attr:`source_fn`.
    """

    def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals):
        super().__init__(fn, fn_ast, auxiliary_ast, namespace, extra_globals)
        # Remember the raw function, then replace ``_fn`` (the attribute the
        # base class exposes, see ``wrap``) with its jit-compiled counterpart.
        self._fn_source = fn
        self._fn = jax.jit(fn)

    @property
    def source_fn(self):
        """The original function passed in, without the jit wrapper."""
        return self._fn_source

    @classmethod
    def wrap(cls, compiledMethod: CompiledMethod):
        """
        Build a jit-wrapped copy of an existing compiled method.

        Args:
            compiledMethod: the base compiled method whose function, ASTs,
                namespace and extra globals are reused

        Returns:
            a new instance of ``cls`` built from ``compiledMethod``'s parts
        """
        raw_fn = compiledMethod._fn
        fn_ast = compiledMethod.ast
        aux_ast = compiledMethod.auxiliary_ast
        namespace = compiledMethod.namespace
        extra_globals = compiledMethod.extra_globals
        return cls(raw_fn, fn_ast, aux_ast, namespace, extra_globals)
class JaxProcessesMixin:
    """
    A mixin for the base Process that adds JAX functionality such as scan and
    implicit jit wrapping.

    List it before a concrete process class in the bases (for example
    ``class JaxJointProcess(JaxProcessesMixin, JointProcess)``). Its
    ``__init__`` and ``compile`` then run first and delegate to the process
    implementation via ``super()``.
    """

    def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs):
        """
        Look at the BaseProcess class for information about other arguments.

        Args:
            name: forwarded unchanged, with ``*args``/``**kwargs``, to the next
                class in the MRO
            use_jit: a flag for if the process should implicitly jit wrap its
                compiled ``run`` method (applied in ``compile``)
        """
        super().__init__(name, *args, **kwargs)
        self._previous_result = None
        self._previous_state = None
        self._use_jit = use_jit

    @property
    def previous_result(self):
        """
        Stores and returns the last result of scan (the second returned value).
        ``None`` until a scan with ``store_results=True`` runs, or after
        ``clear``.
        """
        return self._previous_result

    @property
    def previous_state(self):
        """
        Stores and returns the last returned state of scan (the first returned
        value). ``None`` until a scan with ``store_state=True`` runs, or after
        ``clear``.
        """
        return self._previous_state

    def clear(self):
        """
        Clears out the previous result and state from scan.
        """
        self._previous_result = None
        self._previous_state = None

    def scan(self: "BaseProcess", inputs, current_state=None,
             store_state: bool = True, store_results: bool = True):
        """
        Runs the process through jax's scan method.

        ``self.run.compiled`` is used as the scan body. It is called as
        ``(carry_state, input_slice) -> (new_state, output)``.

        Args:
            inputs: The inputs for scan (use pack rows to generate), must be a
                jax array
            current_state: Optional, the current state of the model; if None
                uses the current global state
            store_state: Optional flag, should the final state be stored in
                the process
            store_results: Optional flag, should the final result be stored in
                the process

        Returns:
            the final state, the final result
        """
        # Compare against None explicitly rather than using ``or``. Truth-testing
        # the supplied state raises for multi-element JAX arrays. It would also
        # discard an explicitly supplied empty state (e.g. ``{}``).
        state = stateManager.state if current_state is None else current_state
        final_state, result = jax.lax.scan(self.run.compiled, state, inputs)
        if store_state:
            self._previous_state = final_state
        if store_results:
            self._previous_result = result
        return final_state, result

    def compile(self: "BaseProcess"):
        """
        For use by the compiler.

        Runs the base process compilation. Then, if ``use_jit`` was enabled at
        construction, it replaces ``self.run.compiled`` with a
        ``JaxCompiledMethod`` wrapping it.
        """
        super().compile()
        if self._use_jit:
            self.run.compiled = JaxCompiledMethod.wrap(self.run.compiled)
class JaxJointProcess(JaxProcessesMixin, JointProcess):
    """A ``JointProcess`` extended with the scan and implicit-jit support of ``JaxProcessesMixin``."""
class JaxMethodProcess(JaxProcessesMixin, MethodProcess):
    """A ``MethodProcess`` extended with the scan and implicit-jit support of ``JaxProcessesMixin``."""