Source code for ngclearn.utils.optim.optim_utils

import functools
from .sgd import sgd_step, sgd_init
from .nag import nag_step, nag_init
from .adam import adam_step, adam_init

def get_opt_init_fn(opt='adam'):
    """Return the initialization function for the named optimizer.

    Args:
        opt: name of the optimizer; one of 'adam', 'nag', or 'sgd'
            (Default: 'adam')

    Returns:
        the initialization function mapped to `opt` -- `adam_init`,
        `nag_init`, or `sgd_init` (returned as-is, not called)

    Raises:
        KeyError: if `opt` is not one of the supported optimizer names
    """
    init_fns = {
        'adam': adam_init,
        'nag': nag_init,
        'sgd': sgd_init,
    }
    return init_fns[opt]
def get_opt_step_fn(opt='adam', **kwargs):
    """Return the step (update) function for the named optimizer.

    Args:
        opt: name of the optimizer; one of 'adam', 'nag', or 'sgd'
            (Default: 'adam')

        **kwargs: the hyper-parameters you want to pass in to the
            optimization function; they are pre-bound as keyword arguments
            of the returned callable

    Returns:
        a `functools.partial` wrapping `adam_step`, `nag_step`, or
        `sgd_step` (whichever is mapped to `opt`) with `kwargs` bound

    Raises:
        KeyError: if `opt` is not one of the supported optimizer names
    """
    step_fns = {
        'adam': adam_step,
        'nag': nag_step,
        'sgd': sgd_step,
    }
    # Look the optimizer up first so that only the selected step function is
    # wrapped (no throw-away partials for the optimizers not requested).
    return functools.partial(step_fns[opt], **kwargs)