Coverage for /home/ubuntu/flatiron/python/flatiron/core/resolve.py: 100%
65 statements
coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
from typing import Type  # noqa F401
from flatiron.core.types import OptStr, Getter  # noqa F401
from pydantic import BaseModel  # noqa F401

from copy import deepcopy

import flatiron.core.config as cfg
import flatiron.core.tools as fict
# ------------------------------------------------------------------------------


def resolve_config(config, model):
    # type: (dict, Type[BaseModel]) -> dict
    '''
    Resolves given Pipeline config.

    Config fields include:

    * framework
    * model
    * dataset
    * optimizer
    * loss
    * metrics
    * callbacks
    * train
    * logger

    Args:
        config (dict): Config dict.
        model (BaseModel): Model config class.

    Returns:
        dict: Resolved config.
    '''
    config = deepcopy(config)
    config = _resolve_model(config, model)
    config = _resolve_pipeline(config)
    config = _resolve_field(config, 'framework')
    config = _resolve_field(config, 'optimizer')
    config = _resolve_field(config, 'loss')
    config = _resolve_field(config, 'metrics')
    return config
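
# Hedged usage sketch (not part of this module): how resolve_config might be
# called. `ExampleModelConfig` and the field values are illustrative
# assumptions; only the top-level keys mirror those built by _generate_config
# below.
#
#     from pydantic import BaseModel
#
#     class ExampleModelConfig(BaseModel):  # hypothetical model config class
#         channels: int = 3
#
#     config = dict(
#         framework=dict(name='torch'),
#         model=dict(channels=3),
#         dataset=dict(source='/mnt/data/dataset'),
#         optimizer=dict(name='SGD'),
#         loss=dict(name='CrossEntropyLoss'),
#         metrics=[dict(name='MeanMetric')],
#         callbacks=dict(project='project-name', root='/tensorboard/parent/dir'),
#         logger={},
#         train={},
#     )
#     resolved = resolve_config(config, ExampleModelConfig)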


def _generate_config(
    framework='torch',
    project='project-name',
    callback_root='/tensorboard/parent/dir',
    dataset='/mnt/data/dataset',
    optimizer='SGD',
    loss='CrossEntropyLoss',
    metrics=['MeanMetric'],
):
    # type: (str, str, str, str, str, str, list[str]) -> dict
    '''
    Generate a pipeline config based on given parameters.

    Args:
        framework (str): Framework name. Default: torch.
        project (str): Project name. Default: project-name.
        callback_root (str): Callback root path. Default: /tensorboard/parent/dir.
        dataset (str): Dataset path. Default: /mnt/data/dataset.
        optimizer (str): Optimizer name. Default: SGD.
        loss (str): Loss name. Default: CrossEntropyLoss.
        metrics (list[str]): Metric names. Default: ['MeanMetric'].

    Returns:
        dict: Generated config.
    '''
    if framework == 'tensorflow':
        if loss == 'CrossEntropyLoss':
            loss = 'CategoricalCrossentropy'
        if metrics == ['MeanMetric']:
            metrics = ['Mean']
    config = dict(
        framework=dict(name=framework),
        dataset=dict(source=dataset),
        model=dict(),
        optimizer=dict(name=optimizer),
        loss=dict(name=loss),
        metrics=[dict(name=x) for x in metrics],
        callbacks=dict(project=project, root=callback_root),
        logger={},
        train={},
    )
    config = _resolve_pipeline(config)
    config = _resolve_field(config, 'framework')
    config = _resolve_field(config, 'optimizer')
    config = _resolve_field(config, 'loss')
    config = _resolve_field(config, 'metrics')
    return config
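
# Hedged sketch of the defaults above (the exact resolved values depend on the
# flatiron.torch / flatiron.tensorflow config modules and are assumed here):
#
#     _generate_config()                        # torch, SGD, CrossEntropyLoss, ['MeanMetric']
#     _generate_config(framework='tensorflow')  # loss -> CategoricalCrossentropy, metrics -> ['Mean']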


def _resolve_model(config, model):
    # type: (dict, Type[BaseModel]) -> dict
    '''
    Resolve and validate given model config.

    Args:
        config (dict): Pipeline config with a model field.
        model (BaseModel): Model config class.

    Returns:
        dict: Config with validated model field.
    '''
    config['model'] = model \
        .model_validate(config['model'], strict=True) \
        .model_dump()
    return config


def _resolve_pipeline(config):
    # type: (dict) -> dict
    '''
    Resolve and validate given pipeline config.

    Args:
        config (dict): Pipeline config.

    Returns:
        dict: Validated pipeline config.
    '''
    # model is validated separately by _resolve_model, so exclude it here
    model = config.pop('model', {})
    config = cfg.PipelineConfig \
        .model_validate(config, strict=True) \
        .model_dump()
    config['model'] = model
    return config


def _resolve_field(config, field):
    # type: (dict, str) -> dict
    '''
    Resolve and validate given pipeline config field.

    Args:
        config (dict): Pipeline config.
        field (str): Config field name.

    Returns:
        dict: Updated pipeline config.
    '''
    # derive class prefix from framework name: tensorflow -> TF, torch -> Torch
    prefix = config['framework']['name']
    if prefix == 'tensorflow':
        prefix = 'TF'
    else:
        prefix = prefix.capitalize()

    # per field: (class_prefix, prepend, config_module, other_module)
    pkg = f'flatiron.{prefix.lower()}'
    lut = dict(
        framework=(f'{prefix}Framework', False, f'{pkg}.config', None               ),  # noqa E202
        optimizer=(f'{prefix}Opt',       True,  f'{pkg}.config', f'{pkg}.optimizer' ),  # noqa E202
        loss     =(f'{prefix}Loss',      True,  f'{pkg}.config', f'{pkg}.loss'      ),  # noqa E202
        metrics  =(f'{prefix}Metric',    True,  f'{pkg}.config', f'{pkg}.metric'    ),  # noqa E202
    )
    keys = ['class_prefix', 'prepend', 'config_module', 'other_module']
    kwargs = dict(zip(keys, lut[field]))  # type: Getter

    # metrics is a list of subconfigs; every other field is a single subconfig
    subconfig = config[field]
    if isinstance(subconfig, list):
        config[field] = [_resolve_subconfig(x, **kwargs) for x in subconfig]
    else:
        config[field] = _resolve_subconfig(subconfig, **kwargs)

    return config
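
# Hedged illustration of the lookup above for framework 'torch' and field
# 'optimizer' (module paths follow the lut; actual resolution is delegated to
# _resolve_subconfig):
#
#     prefix = 'Torch'
#     kwargs = dict(
#         class_prefix='TorchOpt',
#         prepend=True,
#         config_module='flatiron.torch.config',
#         other_module='flatiron.torch.optimizer',
#     )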


def _resolve_subconfig(
    subconfig, class_prefix, prepend, config_module, other_module
):
    # type: (dict, str, bool, str, OptStr) -> dict
    '''
    For use in _resolve_field. Resolves and validates given subconfig.

    If the class is not a custom definition found in the config module or the
    other module, a standard definition is resolved from the config module.
    The class prefix and prepend flag are used to modify the config name field
    in order to make it a valid class name.

    Args:
        subconfig (dict): Subconfig.
        class_prefix (str): Class prefix.
        prepend (bool): Prepend class prefix to name.
        config_module (str): Config module name.
        other_module (OptStr): Other module name or None.

    Returns:
        dict: Validated subconfig.
    '''
    # custom definitions are returned as is
    if config_module is not None:
        if fict.is_custom_definition(subconfig, config_module):
            return subconfig

    if other_module is not None:
        if fict.is_custom_definition(subconfig, other_module):
            return subconfig

    # temporarily rewrite name as a class name, e.g. SGD -> TorchOptSGD
    name = subconfig['name']
    output = deepcopy(subconfig)
    output['name'] = class_prefix
    if prepend:
        output['name'] += name

    output = fict.resolve_module_config(output, config_module)
    output['name'] = name
    return output
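
# Hedged example of the name handling above (the behavior of
# fict.resolve_module_config is assumed, not verified):
#
#     _resolve_subconfig(
#         dict(name='SGD'),
#         class_prefix='TorchOpt',
#         prepend=True,
#         config_module='flatiron.torch.config',
#         other_module='flatiron.torch.optimizer',
#     )
#     # resolves the subconfig as 'TorchOptSGD' against flatiron.torch.config,
#     # then returns it with name restored to 'SGD'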