Coverage for /home/ubuntu/flatiron/python/flatiron/core/pipeline.py: 100%
117 statements
coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
from typing import Any, Optional, Type  # noqa F401
from flatiron.core.types import Compiled, Filepath  # noqa F401
from pydantic import BaseModel  # noqa F401

from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path

import yaml

from flatiron.core.dataset import Dataset
import flatiron.core.logging as filog
import flatiron.core.tools as fict
import flatiron.core.resolve as res
# ------------------------------------------------------------------------------


class PipelineBase(ABC):
    @classmethod
    def read_yaml(cls, filepath):
        # type: (Filepath) -> PipelineBase
        '''
        Construct PipelineBase instance from given yaml file.

        Args:
            filepath (str or Path): YAML file.

        Returns:
            PipelineBase: PipelineBase instance.
        '''
        with open(filepath) as f:
            config = yaml.safe_load(f)
        return cls(config)
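
    # Hypothetical usage (read_yaml is a classmethod, normally called on a
    # concrete subclass; the subclass name below is assumed, not part of this
    # module):
    #
    #   pipeline = UNetPipeline.read_yaml('/path/to/pipeline.yaml')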

    @classmethod
    def from_string(cls, text):
        # type: (str) -> PipelineBase
        '''
        Construct PipelineBase instance from given YAML text.

        Args:
            text (str): YAML text.

        Returns:
            PipelineBase: PipelineBase instance.
        '''
        config = yaml.safe_load(text)
        return cls(config)

    @classmethod
    def generate_config(
        cls,
        framework='torch',
        project='project-name',
        callback_root='/tensorboard/parent/dir',
        dataset='/mnt/data/dataset',
        optimizer='SGD',
        loss='CrossEntropyLoss',
        metrics=['MeanMetric'],
    ):
        # type: (str, str, str, str, str, str, list[str]) -> None
        '''
        Prints a generated pipeline config based on given parameters.

        Args:
            framework (str): Framework name. Default: torch.
            project (str): Project name. Default: project-name.
            callback_root (str): Callback root path. Default: /tensorboard/parent/dir.
            dataset (str): Dataset path. Default: /mnt/data/dataset.
            optimizer (str): Optimizer name. Default: SGD.
            loss (str): Loss name. Default: CrossEntropyLoss.
            metrics (list[str]): Metric names. Default: ['MeanMetric'].
        '''
        config = res._generate_config(
            framework=framework,
            project=project,
            callback_root=callback_root,
            dataset=dataset,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
        )
        print(yaml.safe_dump(config))
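
    # Hedged usage sketch: print a starter config to stdout and redirect it
    # into a YAML file, which can then be edited and loaded with read_yaml
    # (the subclass name below is assumed):
    #
    #   UNetPipeline.generate_config(framework='torch', project='my-project')
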
    # --------------------------------------------------------------------------

    def __init__(self, config):
        # type: (dict) -> None
        '''
        PipelineBase is a base class for machine learning pipelines.

        Args:
            config (dict): PipelineBase config.
        '''
        self.config = res.resolve_config(config, self.model_config())

        # create Dataset instance
        dconf = self.config['dataset']
        src = dconf['source']
        kwargs = dict(
            ext_regex=dconf['ext_regex'],
            labels=dconf['labels'],
            label_axis=dconf['label_axis'],
            calc_file_size=False,
        )
        if Path(src).is_file():
            self.dataset = Dataset.read_csv(src, **kwargs)
        else:
            self.dataset = Dataset.read_directory(src, **kwargs)

        self._compiled = {}  # type: Compiled
        self._train_data = None  # type: Optional[Dataset]
        self._test_data = None  # type: Optional[Dataset]
        self._loaded = False
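
    # Note on config shape: beyond the dataset keys read above, other methods
    # in this class expect config['dataset'] to provide reshape, limit, shuffle,
    # test_size and seed; config['framework'] to provide name and device; and
    # top-level model, optimizer, loss, metrics, callbacks (project, root, ...),
    # train and logger (slack_methods, timezone, ...) sections. The
    # authoritative schema is whatever model_config() and res.resolve_config
    # enforce, so treat this list as descriptive, not normative.
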
    # --------------------------------------------------------------------------

    def _logger(self, method, message, config):
        # type: (str, str, dict) -> filog.SlackLogger
        '''
        Retrieves a logger given a calling method name, message and config.

        Args:
            method (str): Name of method calling logger.
            message (str): Log message or Slack title.
            config (dict): Config dict.

        Returns:
            filog.SlackLogger: Configured logger instance.
        '''
        kwargs = deepcopy(self.config['logger'])
        methods = kwargs['slack_methods']
        del kwargs['slack_methods']
        logger = filog.SlackLogger(message, config, **kwargs)
        if method not in methods:
            logger._message_func = None
            logger._callback = None
        return logger
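
    # Note: Slack notification is gated per method. Only methods named in
    # config['logger']['slack_methods'] keep the logger's message function and
    # callback; for all other methods those hooks are set to None, so the
    # logger presumably falls back to plain logging.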

    def load(self):
        # type: () -> PipelineBase
        '''
        Loads train and test datasets into memory.
        Calls `load` on self._train_data and self._test_data.

        Raises:
            RuntimeError: If train and test data are not datasets.

        Returns:
            PipelineBase: Self.
        '''
        if self._train_data is None or self._test_data is None:
            msg = 'Train and test data not loaded. '
            msg += 'Please call train_test_split method first.'
            raise RuntimeError(msg)

        config = self.config['dataset']
        kwargs = dict(
            reshape=config['reshape'],
            limit=config['limit'],
            shuffle=config['shuffle'],
        )
        with self._logger('load', 'LOAD DATASETS', dict(dataset=config)):
            self._train_data.load(**kwargs)
            self._test_data.load(**kwargs)

        self._loaded = True
        return self

    def unload(self):
        # type: () -> PipelineBase
        '''
        Unload train and test datasets from memory.
        Calls `unload` on self._train_data and self._test_data.

        Raises:
            RuntimeError: If train and test data are not datasets.
            RuntimeError: If train and test data are not loaded.

        Returns:
            PipelineBase: Self.
        '''
        if self._train_data is None or self._test_data is None or not self._loaded:
            msg = 'Train and test data not loaded. '
            msg += 'Please call train_test_split, then load methods first.'
            raise RuntimeError(msg)

        config = self.config['dataset']
        with self._logger('unload', 'UNLOAD DATASETS', dict(dataset=config)):
            self._train_data.unload()
            self._test_data.unload()
        self._loaded = False
        return self

    def train_test_split(self):
        # type: () -> PipelineBase
        '''
        Split dataset into train and test sets.

        Assigns the following instance members:

            * _train_data
            * _test_data

        Returns:
            PipelineBase: Self.
        '''
        config = self.config['dataset']
        with self._logger(
            'train_test_split', 'TRAIN TEST SPLIT', dict(dataset=config)
        ):
            self._train_data, self._test_data = self.dataset.train_test_split(
                test_size=config['test_size'],
                limit=config['limit'],
                shuffle=config['shuffle'],
                seed=config['seed'],
            )
        return self

    def build(self):
        # type: () -> PipelineBase
        '''
        Build machine learning model and assign it to self.model.
        Calls the function returned by `self.model_func` with model params.

        Returns:
            PipelineBase: Self.
        '''
        self._engine.tools.pre_build(self.config['framework']['device'])
        config = self.config['model']
        with self._logger('build', 'BUILD MODEL', dict(model=config)):
            self.model = self.model_func()(**config)
        return self

    @property
    def _engine(self):
        # type: () -> Any
        '''
        Uses config to retrieve flatiron engine subpackage.

        Returns:
            Any: flatiron.tf or flatiron.torch
        '''
        if self.config['framework']['name'] == 'tensorflow':
            import flatiron.tf as __tf_engine
            return __tf_engine
        import flatiron.torch as __torch_engine
        return __torch_engine
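
    # Design note: the engine subpackage is imported lazily inside this
    # property, presumably so that only the framework selected in
    # config['framework']['name'] (tensorflow vs torch) has to be importable
    # at runtime.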

    def compile(self):
        # type: () -> PipelineBase
        '''
        Sets self._compiled to a dictionary of compiled objects.

        Returns:
            PipelineBase: Self.
        '''
        config = deepcopy(self.config)
        msg = dict(
            framework=config['framework'],
            model=config['model'],
            optimizer=config['optimizer'],
            loss=config['loss'],
            metrics=config['metrics'],
        )
        with self._logger('compile', 'COMPILE MODEL', msg):
            self._compiled = self._engine.tools.compile(
                framework=config['framework'],
                model=self.model,
                optimizer=config['optimizer'],
                loss=config['loss'],
                metrics=config['metrics'],
            )
        return self

    def train(self):
        # type: () -> PipelineBase
        '''
        Call model train function with params.

        Returns:
            PipelineBase: Self.
        '''
        engine = self._engine

        callbacks = self.config['callbacks']
        train = self.config['train']
        log = self.config['logger']
        ext = 'safetensors'
        if self.config['framework']['name'] == 'tensorflow':
            ext = 'keras'

        with self._logger('train', 'TRAIN MODEL', self.config):
            # create tensorboard
            tb = fict.get_tensorboard_project(
                project=callbacks['project'],
                root=callbacks['root'],
                timezone=log['timezone'],
                extension=ext,
            )

            # create checkpoint params and callbacks
            ckpt_params = deepcopy(callbacks)
            del ckpt_params['project']
            del ckpt_params['root']
            callbacks = engine.tools.get_callbacks(
                tb['log_dir'], tb['checkpoint_pattern'], ckpt_params,
            )

            # train model
            engine.tools.train(
                compiled=self._compiled,
                callbacks=callbacks,
                train_data=self._train_data,
                test_data=self._test_data,
                params=train,
            )
        return self

    def run(self):
        # type: () -> PipelineBase
        '''
        Run the following pipeline operations:

            * build
            * compile
            * train_test_split
            * load (for tensorflow only)
            * train

        Returns:
            PipelineBase: Self.
        '''
        if self.config['framework']['name'] == 'tensorflow':
            return self \
                .build() \
                .compile() \
                .train_test_split() \
                .load() \
                .train()

        return self \
            .build() \
            .compile() \
            .train_test_split() \
            .train()
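
    # End-to-end usage sketch (the subclass name and config path below are
    # hypothetical, not part of this module):
    #
    #   pipeline = UNetPipeline.read_yaml('/mnt/configs/pipeline.yaml')
    #   pipeline.run()
    #
    # which, per run() above, is equivalent to chaining
    # build().compile().train_test_split().train(), with load() inserted
    # before train() when the framework is tensorflow.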

    @abstractmethod
    def model_config(self):
        # type: () -> Type[BaseModel]
        '''
        Subclasses of PipelineBase will need to define a config class for models
        created in the build method.

        Returns:
            BaseModel: Pydantic BaseModel config class.
        '''
        pass  # pragma: no cover

    @abstractmethod
    def model_func(self):
        # type: () -> Any
        '''
        Subclasses of PipelineBase need to define a function that builds and
        returns a machine learning model.

        Returns:
            object: Function that builds the machine learning model.
        '''
        pass  # pragma: no cover
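
# Minimal sketch of a concrete subclass. The names UNetPipeline, UNetConfig and
# get_unet_model are illustrative assumptions, not part of flatiron:
#
#   class UNetPipeline(PipelineBase):
#       def model_config(self):
#           # pydantic BaseModel class describing config['model'] fields
#           return UNetConfig
#
#       def model_func(self):
#           # callable that build() invokes as get_unet_model(**config['model'])
#           return get_unet_model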