Coverage for /home/ubuntu/flatiron/python/flatiron/core/tools.py: 100%
137 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
1from typing import Any, Callable, Optional, Union # noqa F401
2from http.client import HTTPResponse # noqa F401
3from lunchbox.stopwatch import StopWatch # noqa F401
4from flatiron.core.types import Filepath, OptInt, OptFloat, Getter # noqa F401
5import pandas as pd # noqa F401
7from datetime import datetime
8from pathlib import Path
9import inspect
10import os
11import random
12import re
13import sys
15from lunchbox.enforce import Enforce
16import lunchbox.tools as lbt
17import pytz
18import yaml
19# ------------------------------------------------------------------------------
def get_tensorboard_project(
    project, root='/mnt/storage', timezone='UTC', extension='keras'
):
    # type: (Filepath, Filepath, str, str) -> dict[str, str]
    '''
    Creates directory structure for Tensorboard project.

    Creates <root>/<project>/tensorboard and
    <root>/<project>/tensorboard/<timestamp>/models (parents included,
    existing directories are left as they are).

    Args:
        project (str): Name of project.
        root (str or Path): Tensorboard parent directory. Default: /mnt/storage
        timezone (str, optional): Timezone. Default: UTC.
        extension (str, optional): File extension.
            Options: [keras, pth, safetensors]. Default: keras.

    Raises:
        EnforceError: If extension is not keras, pth or safetensors.

    Returns:
        dict: Project details with keys root_dir, log_dir, model_dir and
            checkpoint_pattern.
    '''
    # The message must list every accepted value; {a} is a placeholder that
    # Enforce presumably fills with the offending value.
    msg = 'Extension must be keras, pth or safetensors. Given value: {a}.'
    Enforce(extension, 'in', ['keras', 'pth', 'safetensors'], message=msg)
    # --------------------------------------------------------------------------

    # create timestamp
    timestamp = datetime \
        .now(tz=pytz.timezone(timezone)) \
        .strftime('d-%Y-%m-%d_t-%H-%M-%S')

    # create directories
    root_dir = Path(root, project, 'tensorboard').as_posix()
    log_dir = Path(root_dir, timestamp).as_posix()
    model_dir = Path(log_dir, 'models').as_posix()
    os.makedirs(root_dir, exist_ok=True)
    # makedirs also creates the intermediate log_dir
    os.makedirs(model_dir, exist_ok=True)

    # checkpoint pattern; '{epoch:03d}' is deliberately left unformatted so
    # it can be filled in per epoch by the checkpoint callback
    epoch = '{epoch:03d}'
    target = f'p-{project}_{timestamp}_e-{epoch}.{extension}'
    target = Path(model_dir, target).as_posix()

    output = dict(
        root_dir=root_dir,
        log_dir=log_dir,
        model_dir=model_dir,
        checkpoint_pattern=target,
    )
    return output
def enforce_callbacks(log_directory, checkpoint_pattern):
    # type: (Filepath, str) -> None
    '''
    Enforces callback parameters.

    Args:
        log_directory (str or Path): Tensorboard project log directory.
        checkpoint_pattern (str): Filepath pattern for checkpoint callback.

    Raises:
        EnforceError: If log directory does not exist.
        EnforceError: If checkpoint pattern does not contain '{epoch}'.
    '''
    log_dir = Path(log_directory)
    # Enforce treats its message as a format string (which is why braces are
    # doubled below); a directory name containing "{" or "}" would otherwise
    # break message formatting instead of raising EnforceError.
    msg = f'Log directory: {log_dir} does not exist.'
    msg = msg.replace('{', '{{').replace('}', '}}')
    Enforce(log_dir.is_dir(), '==', True, message=msg)

    # accept any "{epoch...}" placeholder, e.g. {epoch} or {epoch:03d}
    match = re.search(r'\{epoch.*?\}', checkpoint_pattern)
    msg = "Checkpoint pattern must contain '{epoch}'. "
    msg += f'Given value: {checkpoint_pattern}'
    msg = msg.replace('{', '{{').replace('}', '}}')
    Enforce(match, '!=', None, message=msg)
def enforce_getter(value):
    # type: (Getter) -> None
    '''
    Ensures the given value is a dict that contains a name key.

    Args:
        value (dict): Getter config to validate.

    Raises:
        EnforceError: If value is not a dict or has no name key.
    '''
    error_message = 'Value must be a dict with a name key.'
    # type check first, so the membership check only ever sees a dict
    Enforce(value, 'instance of', dict, message=error_message)
    Enforce('name', 'in', value, message=error_message)
112# MISC--------------------------------------------------------------------------
def pad_layer_name(name, length=18):
    # type: (str, int) -> str
    '''
    Widens the underscores of a layer name so the name reaches a given length.

    Args:
        name (str): Layer name to be padded.
        length (int): Length of output string. Default: 18.

    Returns:
        str: Padded layer name. The name is returned unchanged if length is 0.
    '''
    if length == 0:
        return name

    # a name without underscores gets one appended so there is a run to widen
    padded = name if '_' in name else name + '_'

    # Every run of underscores is replaced by the same filler, sized from the
    # number of non-underscore characters. With several runs the result is
    # longer than length; a non-positive gap removes the underscores.
    letter_count = len(padded.replace('_', ''))
    filler = '_' * (length - letter_count)
    return re.sub('_+', filler, padded)
def unindent(text, spaces=4):
    # type: (str, int) -> str
    '''
    Removes a fixed number of leading spaces from every line of a text block.

    Lines that start with fewer than the given number of spaces are left
    untouched.

    Args:
        text (str): Text block to unindent.
        spaces (int, optional): Number of spaces to remove. Default: 4.

    Returns:
        str: Unindented text.
    '''
    leading = re.compile(f'^ {{{spaces!s}}}')
    lines = text.split('\n')
    return '\n'.join(leading.sub('', line) for line in lines)
def slack_it(
    title,  # type: str
    channel,  # type: str
    url,  # type: str
    config=None,  # type: Optional[dict]
    stopwatch=None,  # type: Optional[StopWatch]
    timezone='UTC',  # type: str
    suppress=False,  # type: bool
):
    # type: (...) -> Union[str, HTTPResponse]
    '''
    Compose a message from given arguments and post it to slack.

    Args:
        title (str): Post title.
        channel (str): Slack channel.
        url (str): Slack URL.
        config (dict, optional): Parameter dict. Default: None.
        stopwatch (StopWatch, optional): StopWatch instance. Default: None.
        timezone (str, optional): Timezone. Default: UTC.
        suppress (bool, optional): Return message, rather than post it to Slack.
            Default: False.

    Returns:
        str or HTTPResponse: The composed message if suppress is True,
            otherwise the Slack response.
    '''
    now = datetime.now(tz=pytz.timezone(timezone)).isoformat()
    cfg = config or {}
    delta = 'none'
    hdelta = 'none'
    if stopwatch is not None:
        hdelta = stopwatch.human_readable_delta
        delta = str(stopwatch.delta)

    config_ = yaml.safe_dump(cfg, indent=4)

    # Build the message line by line instead of unindenting an indented
    # template: unindenting after interpolation would also strip leading
    # spaces from nested YAML lines in config_ and corrupt the config dump.
    message = '\n'.join([
        title.upper(),
        '',
        'RUN TIME:',
        f'```{hdelta} ({delta})```',
        'POST TIME:',
        f'```{now}```',
        'CONFIG:',
        f'```{config_}```',
    ])

    if suppress:
        return message
    return lbt.post_to_slack(url, channel, message)  # pragma: no cover
def resolve_kwargs(kwargs, engine, optimizer, return_type='both'):
    # type: (dict, str, str, str) -> dict
    '''
    Filter keyword arguments based on prefix and return them minus the prefix.

    Keys without a "__" separator are unprefixed. A key "<head>__<tail>" is
    kept, as "<tail>", when head starts with optimizer, equals engine, or
    equals "<engine>_<optimizer>". All other prefixed keys are dropped.

    Args:
        kwargs (dict): Kwargs dict.
        engine (str): Deep learning framework.
        optimizer (str): Optimizer name.
        return_type (str, optional): Which kind of keys to return.
            Options: [prefixed, unprefixed, both]. Default: both.
            Any other value behaves like both.

    Returns:
        dict: Resolved kwargs. With both, an unprefixed value overrides a
            prefixed value that resolves to the same key.
    '''
    prefixed = {}
    unprefixed = {}
    for key, val in kwargs.items():
        if '__' not in key:
            unprefixed[key] = val
            continue

        # split on the first separator only; the tail may itself contain "__"
        head, tail = key.split('__', 1)

        # engine and optimizer are compared literally, never used as regex
        # patterns, so names containing special characters are handled
        keep = (
            head.startswith(optimizer)
            or head == engine
            or head == f'{engine}_{optimizer}'
        )
        if keep:
            prefixed[tail] = val

    if return_type == 'prefixed':
        return prefixed
    elif return_type == 'unprefixed':
        return unprefixed
    prefixed.update(unprefixed)
    return prefixed
def train_test_split(data, test_size=0.2, shuffle=True, seed=None, limit=None):
    # type: (pd.DataFrame, float, bool, OptFloat, OptInt) -> tuple[pd.DataFrame, pd.DataFrame]
    '''
    Split DataFrame into train and test DataFrames.

    Args:
        data (pd.DataFrame): DataFrame.
        test_size (float, optional): Test set size as a proportion.
            Default: 0.2.
        shuffle (bool, optional): Randomize data before splitting.
            Default: True.
        seed (int, optional): Seed number. Default: None.
        limit (int, optional): Limit the total length of train and test.
            Default: None.

    Raises:
        EnforceError: If data is not a DataFrame.
        EnforceError: If test_size is not between 0 and 1.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: Train and test DataFrames.
    '''
    Enforce(data, 'instance of', pd.DataFrame)
    Enforce(test_size, '>=', 0)
    Enforce(test_size, '<=', 1)
    # --------------------------------------------------------------------------

    # Work with row positions, not index labels: .loc with a duplicated label
    # returns every row carrying it, which would duplicate rows and leak them
    # across the train/test split. random.shuffle permutes by position only,
    # so for a unique index this gives the same split as shuffling labels.
    positions = list(range(len(data)))
    if shuffle:
        if seed is not None:
            rand = random.Random()
            rand.seed(seed)
            rand.shuffle(positions)
        else:
            random.shuffle(positions)

    if limit is not None:
        positions = positions[:limit]

    k = int(len(positions) * (1 - test_size))
    train = data.iloc[positions[:k]].copy()
    test = data.iloc[positions[k:]].copy()
    return train, test
291# MODULE-FUNCS------------------------------------------------------------------
def get_module(name):
    # type: (str) -> Any
    '''
    Look up an already-imported module by name.

    Only sys.modules is consulted; nothing new is imported.

    Args:
        name (str): Module name.

    Raises:
        NotImplementedError: If module is not found.

    Returns:
        object: Module.
    '''
    loaded = sys.modules
    if name not in loaded:
        raise NotImplementedError(f'Module not found: {name}')
    return loaded[name]
def get_module_function(name, module):
    # type: (str, str) -> Callable[[Any], Any]
    '''
    Look up a function by name among the members of a loaded module.

    Args:
        name (str): Function name.
        module (str): Module name.

    Raises:
        NotImplementedError: If function is not found in module.

    Returns:
        function: Module function.
    '''
    functions = inspect.getmembers(get_module(module), inspect.isfunction)
    for func_name, func in functions:
        if func_name == name:
            return func
    raise NotImplementedError(f'Function not found: {name}')
def get_module_class(name, module):
    # type: (str, str) -> Any
    '''
    Look up a class by name among the members of a loaded module.

    Args:
        name (str): Class name.
        module (str): Module name.

    Raises:
        NotImplementedError: If class is not found in module.

    Returns:
        class: Module class.
    '''
    classes = inspect.getmembers(get_module(module), inspect.isclass)
    for class_name, klass in classes:
        if class_name == name:
            return klass
    raise NotImplementedError(f'Class not found: {name}')
def resolve_module_config(config, module):
    # type: (Getter, str) -> Getter
    '''
    Validate a config against the class named by its name key and return the
    validated dict.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        dict: Resolved config dict.
    '''
    enforce_getter(config)
    # --------------------------------------------------------------------------

    # presumably a pydantic model class (model_validate / model_dump API)
    config_class = get_module_class(config['name'], module)
    validated = config_class.model_validate(config)
    return validated.model_dump()
def is_custom_definition(config, module):
    # type: (Getter, str) -> bool
    '''
    Report whether the config's name refers to a function or class defined
    in the given module.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        bool: True if config is of custom defined code.
    '''
    enforce_getter(config)
    # --------------------------------------------------------------------------

    name = config['name']
    # functions are checked before classes; either one counts as a match
    for lookup in (get_module_function, get_module_class):
        try:
            lookup(name, module)
        except NotImplementedError:
            continue
        return True
    return False