Coverage for /home/ubuntu/flatiron/python/flatiron/tf/tools.py: 100%
47 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
1from typing import Any, Optional # noqa F401
2from flatiron.core.dataset import Dataset # noqa F401
3from flatiron.core.types import Compiled, Filepath, Getter # noqa: F401
5from copy import deepcopy
6import math
8from tensorflow import keras # noqa F401
9from keras import callbacks as tfcallbacks
10import tensorflow as tf
12import flatiron.core.tools as fict
13import flatiron.tf.loss as fi_tfloss
14import flatiron.tf.metric as fi_tfmetric
15import flatiron.tf.optimizer as fi_tfoptim
# Type alias: named callbacks returned by get_callbacks — maps 'tensorboard'
# and 'checkpoint' to their respective Keras callback instances.
Callbacks = dict[str, tfcallbacks.TensorBoard | tfcallbacks.ModelCheckpoint]
# ------------------------------------------------------------------------------
def get(config, module, fallback_module):
    # type: (Getter, str, str) -> Any
    '''
    Given a config and set of modules return an instance or function.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.
        fallback_module (str): Fallback module, either a tf or torch module.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        object: Instance or function.
    '''
    fict.enforce_getter(config)
    # --------------------------------------------------------------------------
    kwargs = deepcopy(config)
    name = kwargs.pop('name')

    # 1. Prefer a function defined in the caller's own module.
    try:
        return fict.get_module_function(name, module)
    except NotImplementedError:
        pass

    # 2. Fall back to the framework module's own `get` factory.
    try:
        return fict.get_module(fallback_module) \
            .get(dict(class_name=name, config=kwargs))
    except ValueError:
        pass

    # 3. Last resort: look up the class in the fallback module and
    #    instantiate it with the remaining config keys as kwargs.
    return fict.get_module_class(name, fallback_module)(**kwargs)
def get_callbacks(log_directory, checkpoint_pattern, checkpoint_params=None):
    # type: (Filepath, str, Optional[dict]) -> Callbacks
    '''
    Create a dict of callbacks for TensorFlow model.

    Args:
        log_directory (str or Path): Tensorboard project log directory.
        checkpoint_pattern (str): Filepath pattern for checkpoint callback.
        checkpoint_params (dict, optional): Params to be passed to checkpoint
            callback. Default: None.

    Raises:
        EnforceError: If log directory does not exist.
        EnforceError: If checkpoint pattern does not contain '{epoch}'.

    Returns:
        dict: dict with Tensorboard and ModelCheckpoint callbacks.
    '''
    # None sentinel instead of a mutable {} default — a shared default dict
    # would persist mutations across calls
    checkpoint_params = checkpoint_params or {}
    fict.enforce_callbacks(log_directory, checkpoint_pattern)
    return dict(
        tensorboard=tfcallbacks.TensorBoard(
            log_dir=log_directory, histogram_freq=1, update_freq=1
        ),
        checkpoint=tfcallbacks.ModelCheckpoint(
            checkpoint_pattern, **checkpoint_params
        ),
    )
def pre_build(device):
    # type: (str) -> None
    '''
    Sets hardware device.

    Args:
        device (str): Hardware device.
    '''
    # Hide all GPUs from TensorFlow when CPU-only execution is requested;
    # any other device value is a no-op here.
    if device != 'cpu':
        return
    tf.config.set_visible_devices([], 'GPU')
def compile(framework, model, optimizer, loss, metrics):
    # type: (Getter, Any, Getter, Getter, list[Getter]) -> Getter
    '''
    Call `model.compile` on given model with kwargs.

    Args:
        framework (dict): Framework dict.
        model (Any): Model to be compiled.
        optimizer (dict): Optimizer settings.
        loss (dict): Loss to be compiled.
        metrics (list[dict]): Metrics function to be compiled.

    Returns:
        dict: Dict of compiled objects.
    '''
    # NOTE(review): mutates the caller's framework dict in place — 'name' and
    # 'device' are removed, and every remaining key is forwarded verbatim to
    # model.compile as a keyword argument.
    framework.pop('name')
    framework.pop('device')
    model.compile(
        # resolve each config dict to a concrete optimizer/loss/metric object
        optimizer=fi_tfoptim.get(optimizer),
        loss=fi_tfloss.get(loss),
        metrics=[fi_tfmetric.get(m) for m in metrics],
        **framework,
    )
    return dict(model=model)
def train(
    compiled,  # type: Compiled
    callbacks,  # type: Callbacks
    train_data,  # type: Dataset
    test_data,  # type: Optional[Dataset]
    params,  # type: dict
):
    # type: (...) -> None
    '''
    Train TensorFlow model.

    Args:
        compiled (dict): Compiled objects.
        callbacks (dict): Dict of callbacks.
        train_data (Dataset): Training dataset.
        test_data (Dataset): Test dataset.
        params (dict): Training params.
    '''
    model = compiled['model']
    batch_size = params['batch_size']
    x_train, y_train = train_data.xy_split()

    # validation data is optional
    val = test_data.xy_split() if test_data is not None else None

    # one optimizer step per batch, rounding up for the final partial batch
    steps = math.ceil(x_train.shape[0] / batch_size)

    model.fit(
        x=x_train,
        y=y_train,
        callbacks=list(callbacks.values()),
        validation_data=val,
        steps_per_epoch=steps,
        batch_size=params.get('batch_size', None),
        epochs=params.get('epochs', 1),
        verbose=params.get('verbose', 'auto'),
        validation_split=params.get('validation_split', 0.0),
        shuffle=params.get('shuffle', True),
        initial_epoch=params.get('initial_epoch', 0),
        validation_freq=params.get('validation_freq', 1),
        # class_weight=params.get('class_weight', None),
        # sample_weight=params.get('sample_weight', None),
        # validation_steps=params.get('validation_steps', None),
        # validation_batch_size=params.get('validation_batch_size', None),
    )