Coverage for /home/ubuntu/flatiron/python/flatiron/core/resolve.py: 100%

65 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

1from typing import Type # noqa F401 

2from flatiron.core.types import OptStr, Getter # noqa F401 

3from pydantic import BaseModel # noqa F401 

4 

5from copy import deepcopy 

6 

7import flatiron.core.config as cfg 

8import flatiron.core.tools as fict 

9# ------------------------------------------------------------------------------ 

10 

11 

def resolve_config(config, model):
    # type: (dict, Type[BaseModel]) -> dict
    '''
    Resolves given Pipeline config.

    Works on a deep copy, so the given config dict is left untouched.
    The model section is resolved first, then the pipeline as a whole, then
    the framework, optimizer, loss and metrics fields, in that order.

    Config fields include:

        * framework
        * model
        * dataset
        * optimizer
        * loss
        * metrics
        * callbacks
        * train
        * logger

    Args:
        config (dict): Config dict.
        model (BaseModel): Model config class.

    Returns:
        dict: Resolved config.
    '''
    resolved = _resolve_pipeline(_resolve_model(deepcopy(config), model))
    # field order is kept as-is: framework is resolved before the rest
    for field in ('framework', 'optimizer', 'loss', 'metrics'):
        resolved = _resolve_field(resolved, field)
    return resolved

43 

44 

def _generate_config(
    framework='torch',
    project='project-name',
    callback_root='/tensorboard/parent/dir',
    dataset='/mnt/data/dataset',
    optimizer='SGD',
    loss='CrossEntropyLoss',
    metrics=['MeanMetric'],
):
    # type: (str, str, str, str, str, str, list[str]) -> dict
    '''
    Generate a pipeline config based on given parameters.

    When framework is tensorflow, a loss or metrics argument that still
    holds its default value is replaced by CategoricalCrossentropy and
    ['Mean'] respectively. The model section is generated empty and is not
    validated here.

    Args:
        framework (str): Framework name. Default: torch.
        project (str): Project name. Default: project-name.
        callback_root (str): Callback root path.
            Default: /tensorboard/parent/dir.
        dataset (str): Dataset path. Default: /mnt/data/dataset.
        optimizer (str): Optimizer name. Default: SGD.
        loss (str): Loss name. Default: CrossEntropyLoss.
        metrics (list[str]): Metric names. Default: ['MeanMetric'].

    Returns:
        dict: Generated config.
    '''
    # NOTE: the metrics default is a shared list, it is only ever read here
    is_tensorflow = framework == 'tensorflow'
    if is_tensorflow and loss == 'CrossEntropyLoss':
        loss = 'CategoricalCrossentropy'
    if is_tensorflow and metrics == ['MeanMetric']:
        metrics = ['Mean']

    output = {
        'framework': {'name': framework},
        'dataset': {'source': dataset},
        'model': {},
        'optimizer': {'name': optimizer},
        'loss': {'name': loss},
        'metrics': [{'name': metric} for metric in metrics],
        'callbacks': {'project': project, 'root': callback_root},
        'logger': {},
        'train': {},
    }

    output = _resolve_pipeline(output)
    for field in ('framework', 'optimizer', 'loss', 'metrics'):
        output = _resolve_field(output, field)
    return output

92 

93 

94def _resolve_model(config, model): 

95 # type: (dict, Type[BaseModel]) -> dict 

96 ''' 

97 Resolve and validate given model config. 

98 

99 Args: 

100 config (dict): Model config. 

101 model (BaseModel): Model config class. 

102 

103 Returns: 

104 dict: Validated model config. 

105 ''' 

106 config['model'] = model \ 

107 .model_validate(config['model'], strict=True) \ 

108 .model_dump() 

109 return config 

110 

111 

def _resolve_pipeline(config):
    # type: (dict) -> dict
    '''
    Resolve and validate given pipeline config.

    The 'model' entry is popped off the given dict (defaulting to an empty
    dict) so that it is excluded from PipelineConfig validation. It is then
    reattached, as is, to the validated output.

    Args:
        config (dict): Pipeline config.

    Returns:
        dict: Validated pipeline config.
    '''
    # model section is validated elsewhere, set it aside
    model = config.pop('model', {})

    validated = cfg.PipelineConfig.model_validate(config, strict=True)
    output = validated.model_dump()
    output['model'] = model
    return output

129 

130 

131def _resolve_field(config, field): 

132 # type: (dict, str) -> dict 

133 ''' 

134 Resolve and validate given pipeline config field. 

135 

136 Args: 

137 config (dict): Pipeline config. 

138 field (str): Config field name. 

139 

140 Returns: 

141 dict: Updated pipeline config. 

142 ''' 

143 prefix = config['framework']['name'] 

144 if prefix == 'tensorflow': 

145 prefix = 'TF' 

146 else: 

147 prefix = prefix.capitalize() 

148 

149 pkg = f'flatiron.{prefix.lower()}' 

150 lut = dict( 

151 framework=(f'{prefix}Framework', False, f'{pkg}.config', None ), # noqa E202 

152 optimizer=(f'{prefix}Opt', True, f'{pkg}.config', f'{pkg}.optimizer'), # noqa E202 

153 loss =(f'{prefix}Loss', True, f'{pkg}.config', f'{pkg}.loss' ), # noqa E202 

154 metrics =(f'{prefix}Metric', True, f'{pkg}.config', f'{pkg}.metric' ), # noqa E202 

155 ) 

156 keys = ['class_prefix', 'prepend', 'config_module', 'other_module'] 

157 kwargs = dict(zip(keys, lut[field])) # type: Getter 

158 

159 subconfig = config[field] 

160 if isinstance(subconfig, list): 

161 config[field] = [_resolve_subconfig(x, **kwargs) for x in subconfig] 

162 else: 

163 config[field] = _resolve_subconfig(subconfig, **kwargs) 

164 

165 return config 

166 

167 

def _resolve_subconfig(
    subconfig, class_prefix, prepend, config_module, other_module
):
    # type: (dict, str, bool, str, OptStr) -> dict
    '''
    For use in _resolve_field. Resolves and validates given subconfig.

    A subconfig that is a custom definition of either the config module or
    the other module is returned as is. Otherwise, a copy of it is resolved
    against the config module under a class name built from the class prefix
    (followed by the subconfig name when prepend is True). The original name
    is restored on the output afterwards.

    Args:
        subconfig (dict): Subconfig. Needs a 'name' key unless it is a
            custom definition.
        class_prefix (str): Class prefix.
        prepend (bool): Prepend class prefix to subconfig name. If False,
            the class prefix alone is used as the class name.
        config_module (str): Name of module to resolve subconfig with.
        other_module (str, optional): Name of additional module checked for
            custom definitions.

    Returns:
        dict: Validated subconfig.
    '''
    # custom definitions are passed through untouched
    for module in (config_module, other_module):
        if module is not None and fict.is_custom_definition(subconfig, module):
            return subconfig

    original_name = subconfig['name']
    resolved = deepcopy(subconfig)
    if prepend:
        resolved['name'] = class_prefix + original_name
    else:
        resolved['name'] = class_prefix

    resolved = fict.resolve_module_config(resolved, config_module)
    # callers expect the name they supplied, not the class name
    resolved['name'] = original_name
    return resolved

204 return output