Coverage for /home/ubuntu/flatiron/python/flatiron/torch/optimizer.py: 100%

6 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

1from flatiron.core.types import Getter # noqa F401 

2 

3from torch.nn import Module # noqa: F401 

4 

5import flatiron.torch.tools as fi_torchtools 

6# ------------------------------------------------------------------------------ 

7 

8 

def get(config, model):
    # type: (Getter, Module) -> Module
    '''
    Resolve an optimizer from a config for the given torch model.

    A shallow copy of ``config`` is made and its ``params`` key is set to
    ``model.parameters()``. The copy is then passed to ``fi_torchtools.get``
    together with this module's name and ``'torch.optim'`` as the fallback
    module.

    Args:
        config (dict): Optimizer config. The caller's dict is not modified
            by this function.
        model (Module): Torch model whose parameters are supplied to the
            optimizer under the ``params`` key.

    Returns:
        object: Whatever ``fi_torchtools.get`` resolves from the config.
            This is presumably either a function from this module or a
            ``torch.optim`` optimizer instance; confirm against
            ``fi_torchtools.get``.
    '''
    # NOTE(review): the "-> Module" return type in the type comment above
    # looks inaccurate for an optimizer getter; confirm and correct.

    # Copy so the caller's config is not mutated. model.parameters() returns
    # a single-use generator, and leaving it in the caller's dict would break
    # any later reuse, logging or serialization of that config.
    config = dict(config)
    config['params'] = model.parameters()
    return fi_torchtools.get(config, __name__, 'torch.optim')