# Source code for ngclearn.utils.JaxProcessesMixin

from ngcsimlib import JointProcess, MethodProcess
from ngcsimlib.global_state import stateManager
import jax
from typing import TYPE_CHECKING

from ngcsimlib._src.parser.utils import CompiledMethod

if TYPE_CHECKING:
    from ngcsimlib._src.process.baseProcess import BaseProcess

class JaxCompiledMethod(CompiledMethod):
    """A ``CompiledMethod`` whose callable is wrapped with ``jax.jit``.

    The original, un-jitted function is retained and exposed through
    ``source_fn``.
    """

    def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals):
        """Build the compiled method, then replace its callable with a jitted one.

        Args:
            fn: The plain (un-jitted) function to compile.
            fn_ast: AST of ``fn``, forwarded to ``CompiledMethod``.
            auxiliary_ast: Auxiliary AST, forwarded to ``CompiledMethod``.
            namespace: Namespace, forwarded to ``CompiledMethod``.
            extra_globals: Extra globals, forwarded to ``CompiledMethod``.
        """
        super().__init__(fn, fn_ast, auxiliary_ast, namespace, extra_globals)
        # Overwrite the callable stored by the base class with its jitted form.
        # NOTE(review): presumably CompiledMethod dispatches calls through
        # ``self._fn`` -- confirm against ngcsimlib.
        self._fn = jax.jit(fn)
        self._fn_source = fn

    @property
    def source_fn(self):
        """The original (un-jitted) function this method was built from."""
        return self._fn_source

    @classmethod
    def wrap(cls, compiledMethod: CompiledMethod):
        """Create a jitted copy of an existing ``CompiledMethod``.

        The new instance reuses the source method's AST, auxiliary AST,
        namespace and extra globals.

        If ``compiledMethod`` is already a ``JaxCompiledMethod``, its
        un-jitted ``source_fn`` is reused rather than its jitted callable.
        This avoids jitting an already-jitted function and keeps
        ``source_fn`` pointing at the raw function.

        Args:
            compiledMethod: The method to wrap.

        Returns:
            A new instance of ``cls``.
        """
        if isinstance(compiledMethod, JaxCompiledMethod):
            fn = compiledMethod.source_fn
        else:
            fn = compiledMethod._fn
        return cls(fn,
                   compiledMethod.ast,
                   compiledMethod.auxiliary_ast,
                   compiledMethod.namespace,
                   compiledMethod.extra_globals)
class JaxProcessesMixin:
    """Mixin that adds JAX execution helpers to an ngcsimlib process.

    It is meant to appear before a process class (e.g. ``JointProcess`` or
    ``MethodProcess``) in the bases, so that its ``__init__`` and ``compile``
    run first and delegate to the process through ``super()``. It adds:

    * optional ``jax.jit`` wrapping of the process's compiled ``run`` method
      (see ``compile``), and
    * ``scan``, which drives the compiled ``run`` method with
      ``jax.lax.scan`` and can cache the final state and per-step results.
    """

    def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs):
        """Initialise the underlying process and the mixin's caches.

        Args:
            name: Process name, forwarded to the next class in the MRO.
            *args: Extra positional arguments forwarded to the process.
            use_jit: If True (default), ``compile`` wraps the compiled
                ``run`` method in a ``JaxCompiledMethod`` (i.e. ``jax.jit``).
            **kwargs: Extra keyword arguments forwarded to the process.
        """
        super().__init__(name, *args, **kwargs)
        self._previous_result = None  # per-step outputs of the last stored scan
        self._previous_state = None  # final carry of the last stored scan
        self._use_jit = use_jit

    @property
    def previous_result(self):
        """Per-step outputs saved by the last ``scan`` (None if none saved)."""
        return self._previous_result

    @property
    def previous_state(self):
        """Final state saved by the last ``scan`` (None if none saved)."""
        return self._previous_state

    def clear(self):
        """Forget the state and results cached by previous ``scan`` calls."""
        self._previous_result = None
        self._previous_state = None

    def scan(self: "BaseProcess", inputs, current_state=None,
             save_state: bool = True, store_results: bool = True):
        """Iterate the compiled ``run`` method over ``inputs`` via ``jax.lax.scan``.

        ``self.run.compiled`` is used as the scan body. It is called as
        ``compiled(state, x) -> (new_state, y)`` for each slice ``x`` of
        ``inputs`` along its leading axis.

        Args:
            inputs: Per-step inputs, scanned along the leading axis.
            current_state: Initial state (scan carry). If None, the global
                ``stateManager.state`` is used.
            save_state: If True, store the final state in ``previous_state``.
            store_results: If True, store the stacked per-step outputs in
                ``previous_result``.

        Returns:
            ``(final_state, result)`` exactly as returned by ``jax.lax.scan``.
        """
        # Test against None explicitly. A truthiness test (``or``) would
        # replace an explicitly passed but falsy state (``{}``, a zero-valued
        # scalar array) with the global state. It would also raise for
        # multi-element JAX arrays, whose truth value is ambiguous.
        state = stateManager.state if current_state is None else current_state
        final_state, result = jax.lax.scan(self.run.compiled, state, inputs)
        if save_state:
            self._previous_state = final_state
        if store_results:
            self._previous_result = result
        return final_state, result

    def compile(self: "BaseProcess"):
        """Compile the process and, if enabled, jit its compiled ``run`` method.

        Delegates to the process's own ``compile`` first. When ``use_jit``
        was True at construction, it then replaces ``self.run.compiled``
        with a ``JaxCompiledMethod`` wrapping it.
        """
        super().compile()
        if self._use_jit:
            self.run.compiled = JaxCompiledMethod.wrap(self.run.compiled)
class JaxJointProcess(JaxProcessesMixin, JointProcess):
    """A ``JointProcess`` extended with the JAX helpers of ``JaxProcessesMixin``
    (optional ``jax.jit`` compilation of ``run`` and ``scan``)."""
class JaxMethodProcess(JaxProcessesMixin, MethodProcess):
    """A ``MethodProcess`` extended with the JAX helpers of ``JaxProcessesMixin``
    (optional ``jax.jit`` compilation of ``run`` and ``scan``)."""