Coverage for /home/ubuntu/flatiron/python/flatiron/torch/optimizer.py: 100%
6 statements
coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
from flatiron.core.types import Getter  # noqa: F401

from torch.nn import Module  # noqa: F401
from torch.optim import Optimizer  # noqa: F401

import flatiron.torch.tools as fi_torchtools
6# ------------------------------------------------------------------------------
def get(config, model):
    # type: (Getter, Module) -> Optimizer
    '''
    Get an optimizer for the given model from an optimizer config.

    The caller's config is left untouched: a shallow copy is made, the
    model's parameters are injected into it under the 'params' key, and the
    lookup is delegated to flatiron.torch.tools.get, which searches this
    module and then torch.optim.

    Args:
        config (dict): Optimizer config.
        model (Module): Torch model whose parameters will be optimized.

    Returns:
        Optimizer: Object resolved by fi_torchtools.get from this module or
            torch.optim (presumably an optimizer instantiated from the config
            keys — confirm against fi_torchtools.get).
    '''
    # Copy instead of mutating the caller's dict: model.parameters() returns
    # a one-shot generator. Storing it in the caller's config would leave an
    # exhausted iterator behind if that config is reused, logged, or
    # serialized later.
    config = dict(config)
    config['params'] = model.parameters()
    return fi_torchtools.get(config, __name__, 'torch.optim')