Coverage for /home/ubuntu/flatiron/python/flatiron/core/pipeline.py: 100%

117 statements  

coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

from typing import Any, Optional, Type  # noqa: F401
from flatiron.core.types import Compiled, Filepath  # noqa: F401
from pydantic import BaseModel  # noqa: F401

from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path

import yaml

from flatiron.core.dataset import Dataset
import flatiron.core.logging as filog
import flatiron.core.tools as fict
import flatiron.core.resolve as res
# ------------------------------------------------------------------------------


class PipelineBase(ABC):
    @classmethod
    def read_yaml(cls, filepath):
        # type: (Filepath) -> PipelineBase
        '''
        Construct PipelineBase instance from given YAML file.

        Args:
            filepath (str or Path): YAML file.

        Returns:
            PipelineBase: PipelineBase instance.
        '''
        with open(filepath) as f:
            config = yaml.safe_load(f)
        return cls(config)

    @classmethod
    def from_string(cls, text):
        # type: (str) -> PipelineBase
        '''
        Construct PipelineBase instance from given YAML text.

        Args:
            text (str): YAML text.

        Returns:
            PipelineBase: PipelineBase instance.
        '''
        config = yaml.safe_load(text)
        return cls(config)
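
    # Example (hypothetical usage; ``MyPipeline`` is an assumed concrete
    # subclass, and the YAML path is illustrative):
    #
    #     pipe = MyPipeline.read_yaml('/mnt/data/pipeline.yaml')
    #     pipe = MyPipeline.from_string(Path('/mnt/data/pipeline.yaml').read_text())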

    @classmethod
    def generate_config(
        cls,
        framework='torch',
        project='project-name',
        callback_root='/tensorboard/parent/dir',
        dataset='/mnt/data/dataset',
        optimizer='SGD',
        loss='CrossEntropyLoss',
        metrics=['MeanMetric'],
    ):
        # type: (str, str, str, str, str, str, list[str]) -> None
        '''
        Prints a generated pipeline config based on given parameters.

        Args:
            framework (str): Framework name. Default: torch.
            project (str): Project name. Default: project-name.
            callback_root (str): Callback root path. Default: /tensorboard/parent/dir.
            dataset (str): Dataset path. Default: /mnt/data/dataset.
            optimizer (str): Optimizer name. Default: SGD.
            loss (str): Loss name. Default: CrossEntropyLoss.
            metrics (list[str]): Metric names. Default: ['MeanMetric'].
        '''
        config = res._generate_config(
            framework=framework,
            project=project,
            callback_root=callback_root,
            dataset=dataset,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
        )
        print(yaml.safe_dump(config))
    # --------------------------------------------------------------------------
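
    # Example: generate a starter config, which can be saved to a file and
    # passed back in via ``read_yaml`` (argument values are illustrative):
    #
    #     PipelineBase.generate_config(framework='torch', project='my-project')
    #
    # This prints the generated pipeline config as YAML to stdout.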

    def __init__(self, config):
        # type: (dict) -> None
        '''
        PipelineBase is a base class for machine learning pipelines.

        Args:
            config (dict): PipelineBase config.
        '''
        self.config = res.resolve_config(config, self.model_config())

        # create Dataset instance
        dconf = self.config['dataset']
        src = dconf['source']
        kwargs = dict(
            ext_regex=dconf['ext_regex'],
            labels=dconf['labels'],
            label_axis=dconf['label_axis'],
            calc_file_size=False,
        )
        if Path(src).is_file():
            self.dataset = Dataset.read_csv(src, **kwargs)
        else:
            self.dataset = Dataset.read_directory(src, **kwargs)

        self._compiled = {}  # type: Compiled
        self._train_data = None  # type: Optional[Dataset]
        self._test_data = None  # type: Optional[Dataset]
        self._loaded = False
    # --------------------------------------------------------------------------
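
    # The config schema is validated by ``res.resolve_config``. A hypothetical
    # skeleton, inferred only from the keys this class reads (values are
    # illustrative, not an authoritative schema):
    #
    #     dataset:
    #       source: /mnt/data/dataset   # file -> read_csv, dir -> read_directory
    #       ext_regex: npy$
    #       labels: null
    #       label_axis: -1
    #       test_size: 0.2
    #       limit: null
    #       shuffle: true
    #       seed: 42
    #       reshape: true
    #     framework:
    #       name: torch
    #       device: cuda
    #     logger:
    #       slack_methods: []
    #       timezone: UTC
    #
    # plus model, optimizer, loss, metrics, callbacks and train sections
    # consumed by the build, compile and train methods below.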

    def _logger(self, method, message, config):
        # type: (str, str, dict) -> filog.SlackLogger
        '''
        Retrieves a logger given a method name, message and config.

        Args:
            method (str): Name of method calling logger.
            message (str): Log message or Slack title.
            config (dict): Config dict.

        Returns:
            filog.SlackLogger: Configured logger instance.
        '''
        kwargs = deepcopy(self.config['logger'])
        methods = kwargs['slack_methods']
        del kwargs['slack_methods']
        logger = filog.SlackLogger(message, config, **kwargs)
        if method not in methods:
            logger._message_func = None
            logger._callback = None
        return logger

    def load(self):
        # type: () -> PipelineBase
        '''
        Loads train and test datasets into memory.
        Calls `load` on self._train_data and self._test_data.

        Raises:
            RuntimeError: If train and test data have not been split.

        Returns:
            PipelineBase: Self.
        '''
        if self._train_data is None or self._test_data is None:
            msg = 'Train and test data not loaded. '
            msg += 'Please call train_test_split method first.'
            raise RuntimeError(msg)

        config = self.config['dataset']
        kwargs = dict(
            reshape=config['reshape'],
            limit=config['limit'],
            shuffle=config['shuffle'],
        )
        with self._logger('load', 'LOAD DATASETS', dict(dataset=config)):
            self._train_data.load(**kwargs)
            self._test_data.load(**kwargs)

        self._loaded = True
        return self

    def unload(self):
        # type: () -> PipelineBase
        '''
        Unload train and test datasets from memory.
        Calls `unload` on self._train_data and self._test_data.

        Raises:
            RuntimeError: If train and test data have not been split.
            RuntimeError: If train and test data have not been loaded.

        Returns:
            PipelineBase: Self.
        '''
        if self._train_data is None or self._test_data is None or not self._loaded:
            msg = 'Train and test data not loaded. '
            msg += 'Please call train_test_split, then load methods first.'
            raise RuntimeError(msg)

        config = self.config['dataset']
        with self._logger('unload', 'UNLOAD DATASETS', dict(dataset=config)):
            self._train_data.unload()
            self._test_data.unload()
        self._loaded = False
        return self

    def train_test_split(self):
        # type: () -> PipelineBase
        '''
        Split dataset into train and test sets.

        Assigns the following instance members:

        * _train_data
        * _test_data

        Returns:
            PipelineBase: Self.
        '''
        config = self.config['dataset']
        with self._logger(
            'train_test_split', 'TRAIN TEST SPLIT', dict(dataset=config)
        ):
            self._train_data, self._test_data = self.dataset.train_test_split(
                test_size=config['test_size'],
                limit=config['limit'],
                shuffle=config['shuffle'],
                seed=config['seed'],
            )
        return self
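
    # Example (hypothetical usage): the split/load/unload lifecycle. Calling
    # ``load`` before ``train_test_split`` raises the RuntimeError above.
    #
    #     pipe.train_test_split()  # assigns _train_data and _test_data
    #     pipe.load()              # reads both datasets into memory
    #     ...
    #     pipe.unload()            # frees memory when done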

    def build(self):
        # type: () -> PipelineBase
        '''
        Build machine learning model and assign it to self.model.
        Calls `self.model_func` with model params.

        Returns:
            PipelineBase: Self.
        '''
        self._engine.tools.pre_build(self.config['framework']['device'])
        config = self.config['model']
        with self._logger('build', 'BUILD MODEL', dict(model=config)):
            self.model = self.model_func()(**config)
        return self

    @property
    def _engine(self):
        # type: () -> Any
        '''
        Uses config to retrieve flatiron engine subpackage.

        Returns:
            Any: flatiron.tf or flatiron.torch
        '''
        if self.config['framework']['name'] == 'tensorflow':
            import flatiron.tf as __tf_engine
            return __tf_engine
        import flatiron.torch as __torch_engine
        return __torch_engine

    def compile(self):
        # type: () -> PipelineBase
        '''
        Sets self._compiled to a dictionary of compiled objects.

        Returns:
            PipelineBase: Self.
        '''
        config = deepcopy(self.config)
        msg = dict(
            framework=config['framework'],
            model=config['model'],
            optimizer=config['optimizer'],
            loss=config['loss'],
            metrics=config['metrics'],
        )
        with self._logger('compile', 'COMPILE MODEL', msg):
            self._compiled = self._engine.tools.compile(
                framework=config['framework'],
                model=self.model,
                optimizer=config['optimizer'],
                loss=config['loss'],
                metrics=config['metrics'],
            )
        return self

    def train(self):
        # type: () -> PipelineBase
        '''
        Call model train function with params.

        Returns:
            PipelineBase: Self.
        '''
        engine = self._engine

        callbacks = self.config['callbacks']
        train = self.config['train']
        log = self.config['logger']
        ext = 'safetensors'
        if self.config['framework']['name'] == 'tensorflow':
            ext = 'keras'

        with self._logger('train', 'TRAIN MODEL', self.config):
            # create tensorboard
            tb = fict.get_tensorboard_project(
                project=callbacks['project'],
                root=callbacks['root'],
                timezone=log['timezone'],
                extension=ext,
            )

            # create checkpoint params and callbacks
            ckpt_params = deepcopy(callbacks)
            del ckpt_params['project']
            del ckpt_params['root']
            callbacks = engine.tools.get_callbacks(
                tb['log_dir'], tb['checkpoint_pattern'], ckpt_params,
            )

            # train model
            engine.tools.train(
                compiled=self._compiled,
                callbacks=callbacks,
                train_data=self._train_data,
                test_data=self._test_data,
                params=train,
            )
        return self

    def run(self):
        # type: () -> PipelineBase
        '''
        Run the following pipeline operations:

        * build
        * compile
        * train_test_split
        * load (for tensorflow only)
        * train

        Returns:
            PipelineBase: Self.
        '''
        if self.config['framework']['name'] == 'tensorflow':
            return self \
                .build() \
                .compile() \
                .train_test_split() \
                .load() \
                .train()

        return self \
            .build() \
            .compile() \
            .train_test_split() \
            .train()
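
    # Example (hypothetical usage; ``MyPipeline`` is an assumed concrete
    # subclass):
    #
    #     MyPipeline.read_yaml('/mnt/data/pipeline.yaml').run()
    #
    # which is equivalent to calling build, compile, train_test_split,
    # load (tensorflow only) and train in order.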

    @abstractmethod
    def model_config(self):
        # type: () -> Type[BaseModel]
        '''
        Subclasses of PipelineBase will need to define a config class for models
        created in the build method.

        Returns:
            Type[BaseModel]: Pydantic BaseModel config class.
        '''
        pass  # pragma: no cover

    @abstractmethod
    def model_func(self):
        # type: () -> Any
        '''
        Subclasses of PipelineBase need to define a function that builds and
        returns a machine learning model.

        Returns:
            object: Machine learning model.
        '''
        pass  # pragma: no cover
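

# ------------------------------------------------------------------------------
# Example: a minimal, hypothetical subclass sketch. ``SimpleConfig`` and
# ``build_simple_model`` are illustrative stand-ins, not part of flatiron;
# a real subclass returns its own pydantic config class and model factory.
# Note that build calls ``self.model_func()(**config)``, so the factory's
# keyword arguments must match the fields of the model config section.
#
#     from pydantic import BaseModel
#
#     class SimpleConfig(BaseModel):
#         in_channels: int = 3
#         out_channels: int = 1
#
#     def build_simple_model(in_channels=3, out_channels=1):
#         import torch.nn as nn
#         return nn.Conv2d(in_channels, out_channels, kernel_size=1)
#
#     class SimplePipeline(PipelineBase):
#         def model_config(self):
#             # type: () -> Type[BaseModel]
#             return SimpleConfig
#
#         def model_func(self):
#             # type: () -> Any
#             return build_simple_model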