Coverage for /home/ubuntu/flatiron/python/flatiron/core/tools.py: 100%

137 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

1from typing import Any, Callable, Optional, Union # noqa F401 

2from http.client import HTTPResponse # noqa F401 

3from lunchbox.stopwatch import StopWatch # noqa F401 

4from flatiron.core.types import Filepath, OptInt, OptFloat, Getter # noqa F401 

5import pandas as pd # noqa F401 

6 

7from datetime import datetime 

8from pathlib import Path 

9import inspect 

10import os 

11import random 

12import re 

13import sys 

14 

15from lunchbox.enforce import Enforce 

16import lunchbox.tools as lbt 

17import pytz 

18import yaml 

19# ------------------------------------------------------------------------------ 

20 

21 

def get_tensorboard_project(
    project, root='/mnt/storage', timezone='UTC', extension='keras'
):
    # type: (Filepath, Filepath, str, str) -> dict[str, str]
    '''
    Creates directory structure for Tensorboard project.

    The following directories are created on disk if they do not exist:
    <root>/<project>/tensorboard and
    <root>/<project>/tensorboard/<timestamp>/models.

    Args:
        project (str): Name of project.
        root (str or Path): Tensorboard parent directory. Default: /mnt/storage
        timezone (str, optional): Timezone. Default: UTC.
        extension (str, optional): File extension.
            Options: [keras, pth, safetensors]. Default: keras.

    Raises:
        EnforceError: If extension is not keras, pth or safetensors.

    Returns:
        dict: Project details. Keys: root_dir, log_dir, model_dir,
            checkpoint_pattern. All values are posix style path strings.
    '''
    # the message lists every extension the check below actually accepts
    msg = 'Extension must be keras, pth or safetensors. Given value: {a}.'
    Enforce(extension, 'in', ['keras', 'pth', 'safetensors'], message=msg)
    # --------------------------------------------------------------------------

    # create timestamp
    timestamp = datetime \
        .now(tz=pytz.timezone(timezone)) \
        .strftime('d-%Y-%m-%d_t-%H-%M-%S')

    # create directories
    root_dir = Path(root, project, 'tensorboard').as_posix()
    log_dir = Path(root_dir, timestamp).as_posix()
    model_dir = Path(log_dir, 'models').as_posix()
    os.makedirs(root_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # checkpoint pattern
    # '{epoch:03d}' is kept as a literal placeholder inside the filename; it is
    # not filled in by this function
    epoch = '{epoch:03d}'
    target = f'p-{project}_{timestamp}_e-{epoch}.{extension}'
    target = Path(model_dir, target).as_posix()

    output = dict(
        root_dir=root_dir,
        log_dir=log_dir,
        model_dir=model_dir,
        checkpoint_pattern=target,
    )
    return output

70 

71 

def enforce_callbacks(log_directory, checkpoint_pattern):
    # type: (Filepath, str) -> None
    '''
    Validates the parameters handed to training callbacks.

    Args:
        log_directory (str or Path): Tensorboard project log directory.
        checkpoint_pattern (str): Filepath pattern for checkpoint callback.

    Raises:
        EnforceError: If log directory does not exist.
        EnforceError: If checkpoint pattern does not contain '{epoch}'.
    '''
    directory = Path(log_directory)
    Enforce(
        directory.is_dir(), '==', True,
        message=f'Log directory: {directory} does not exist.',
    )

    # any '{epoch...}' placeholder counts, including ones with a format spec
    found = re.search(r'\{epoch.*?\}', checkpoint_pattern)

    error = "Checkpoint pattern must contain '{epoch}'. "
    error = error + f'Given value: {checkpoint_pattern}'
    # NOTE(review): braces are doubled, presumably because Enforce treats the
    # message as a format template - confirm against lunchbox documentation
    error = error.replace('{', '{{').replace('}', '}}')
    Enforce(found, '!=', None, message=error)

94 

95 

def enforce_getter(value):
    # type: (Getter) -> None
    '''
    Validates that the given value is a dict containing a name key.

    Args:
        value (dict): Dict to be checked.

    Raises:
        EnforceError: If value is not a dict or has no name key.
    '''
    error = 'Value must be a dict with a name key.'
    # type check comes first so the membership test only runs on a dict
    Enforce(value, 'instance of', dict, message=error)
    Enforce('name', 'in', value, message=error)

110 

111 

112# MISC-------------------------------------------------------------------------- 

def pad_layer_name(name, length=18):
    # type: (str, int) -> str
    '''
    Widens the underscore runs of a layer name so that the result reaches a
    given length.

    A name without any underscore gets one appended before padding. Every run
    of underscores is replaced with the same number of underscores.

    Args:
        name (str): Layer name to be padded.
        length (int): Length of output string. Default: 18.

    Returns:
        str: Padded layer name. The name is returned untouched if length is 0.
    '''
    if length == 0:
        return name

    padded = name if '_' in name else name + '_'

    # number of underscores needed on top of the non-underscore characters
    width = length - len(padded.replace('_', ''))
    filler = '_' * width
    return re.sub('_+', filler, padded)

133 

134 

def unindent(text, spaces=4):
    # type: (str, int) -> str
    '''
    Strips a fixed number of leading spaces from each line of a text block.

    Lines that begin with fewer than the given number of spaces are left as
    they are.

    Args:
        text (str): Text block to unindent.
        spaces (int, optional): Number of spaces to remove. Default: 4.

    Returns:
        str: Unindented text.
    '''
    # matches exactly `spaces` spaces at the start of a line
    leading = re.compile('^ {' + str(spaces) + '}')
    return '\n'.join(leading.sub('', line) for line in text.split('\n'))

152 

153 

def slack_it(
    title,  # type: str
    channel,  # type: str
    url,  # type: str
    config=None,  # type: Optional[dict]
    stopwatch=None,  # type: Optional[StopWatch]
    timezone='UTC',  # type: str
    suppress=False,  # type: bool
):
    # type: (...) -> Union[str, HTTPResponse]
    '''
    Compose a message from given arguments and post it to slack.

    Args:
        title (str): Post title. It is uppercased in the message.
        channel (str): Slack channel.
        url (str): Slack URL.
        config (dict, optional): Parameter dict. Default: None.
        stopwatch (StopWatch, optional): StopWatch instance. Default: None.
        timezone (str, optional): Timezone. Default: UTC.
        suppress (bool, optional): Return message, rather than post it to Slack.
            Default: False.

    Returns:
        str or HTTPResponse: Message text if suppress is True, otherwise the
            Slack response.
    '''
    now = datetime.now(tz=pytz.timezone(timezone)).isoformat()
    cfg = config or {}
    delta = 'none'
    hdelta = 'none'
    if stopwatch is not None:
        hdelta = stopwatch.human_readable_delta
        delta = str(stopwatch.delta)

    config_ = yaml.safe_dump(cfg, indent=4)

    # The template is unindented BEFORE values are interpolated. Unindenting
    # afterwards would also strip 8 leading spaces from config lines nested
    # two levels deep (indent=4), flattening the YAML structure.
    template = '''
        {title}

        RUN TIME:
        ```{hdelta} ({delta})```
        POST TIME:
        ```{now}```
        CONFIG:
        ```{config}```
    '''[1:-1]
    message = unindent(template, spaces=8).format(
        title=title.upper(),
        hdelta=hdelta,
        delta=delta,
        now=now,
        config=config_,
    )

    if suppress:
        return message
    return lbt.post_to_slack(url, channel, message)  # pragma: no cover

204 

205 

def resolve_kwargs(kwargs, engine, optimizer, return_type='both'):
    # type: (dict, str, str, str) -> dict
    '''
    Filter keyword arguments based on prefix and return them minus the prefix.

    A key is prefixed if it contains '__'. The text before the first '__' is
    the prefix and the text after it is the resolved key. A prefixed key is
    kept only if its prefix starts with the optimizer name, equals the engine
    name or equals '<engine>_<optimizer>'. Other prefixed keys are dropped.

    Args:
        kwargs (dict): Kwargs dict.
        engine (str): Deep learning framework.
        optimizer (str): Optimizer name.
        return_type (str, optional): Which kind of keys to return.
            Options: [prefixed, unprefixed, both]. Default: both.

    Returns:
        dict: Resolved kwargs. With return_type both, unprefixed keys override
            resolved prefixed keys of the same name.
    '''
    prefixed = {}
    unprefixed = {}
    combined = f'{engine}_{optimizer}'
    for key, val in kwargs.items():
        head, sep, tail = key.partition('__')
        if not sep:
            unprefixed[key] = val
            continue

        # plain string comparison; engine and optimizer are never used as
        # regex patterns, so names with regex metacharacters are safe
        if head.startswith(optimizer) or head in (engine, combined):
            prefixed[tail] = val

    if return_type == 'prefixed':
        return prefixed
    elif return_type == 'unprefixed':
        return unprefixed
    prefixed.update(unprefixed)
    return prefixed

246 

247 

def train_test_split(data, test_size=0.2, shuffle=True, seed=None, limit=None):
    # type: (pd.DataFrame, float, bool, OptFloat, OptInt) -> tuple[pd.DataFrame, pd.DataFrame]
    '''
    Divides a DataFrame into a train DataFrame and a test DataFrame.

    Rows are selected by index label and both outputs are copies.

    Args:
        data (pd.DataFrame): DataFrame.
        test_size (float, optional): Test set size as a proportion.
            Default: 0.2.
        shuffle (bool, optional): Randomize data before splitting.
            Default: True.
        seed (int, optional): Seed number. Default: None.
        limit (int, optional): Limit the total length of train and test.
            Default: None.

    Raises:
        EnforceError: If data is not a DataFrame.
        EnforceError: If test_size is not between 0 and 1.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: Train and test DataFrames.
    '''
    Enforce(data, 'instance of', pd.DataFrame)
    Enforce(test_size, '>=', 0)
    Enforce(test_size, '<=', 1)
    # --------------------------------------------------------------------------

    labels = data.index.tolist()
    if shuffle:
        # a seed gets its own generator; otherwise the module level one is used
        shuffler = random if seed is None else random.Random(seed)
        shuffler.shuffle(labels)

    if limit is not None:
        labels = labels[:limit]

    # the first `cut` labels go to train, the remainder to test
    cut = int(len(labels) * (1 - test_size))
    train = data.loc[labels[:cut]].copy()
    test = data.loc[labels[cut:]].copy()
    return train, test

289 

290 

291# MODULE-FUNCS------------------------------------------------------------------ 

def get_module(name):
    # type: (str) -> Any
    '''
    Looks up an already imported module by name.

    The lookup is done in sys.modules only; nothing is imported.

    Args:
        name (str): Module name.

    Raises:
        NotImplementedError: If module is not found.

    Returns:
        object: Module.
    '''
    loaded = sys.modules
    if name not in loaded:
        raise NotImplementedError(f'Module not found: {name}')
    return loaded[name]

310 

311 

def get_module_function(name, module):
    # type: (str, str) -> Callable[[Any], Any]
    '''
    Looks up a function by name within a given module.

    Args:
        name (str): Function name.
        module (str): Module name.

    Raises:
        NotImplementedError: If function is not found in module.

    Returns:
        function: Module function.
    '''
    source = get_module(module)
    # only members that pass inspect.isfunction are candidates
    candidates = dict(inspect.getmembers(source, inspect.isfunction))
    if name not in candidates:
        raise NotImplementedError(f'Function not found: {name}')
    return candidates[name]

332 

333 

def get_module_class(name, module):
    # type: (str, str) -> Any
    '''
    Looks up a class by name within a given module.

    Args:
        name (str): Class name.
        module (str): Module name.

    Raises:
        NotImplementedError: If class is not found in module.

    Returns:
        class: Module class.
    '''
    source = get_module(module)
    # only members that pass inspect.isclass are candidates
    candidates = dict(inspect.getmembers(source, inspect.isclass))
    if name not in candidates:
        raise NotImplementedError(f'Class not found: {name}')
    return candidates[name]

354 

355 

def resolve_module_config(config, module):
    # type: (Getter, str) -> Getter
    '''
    Validates a config against the class in the given module that is named by
    the config's name key, and returns the validated result as a dict.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        dict: Resolved config dict.
    '''
    enforce_getter(config)
    # --------------------------------------------------------------------------

    # NOTE(review): model_validate and model_dump look like the pydantic model
    # API - confirm that every class resolved here provides them
    validator = get_module_class(config['name'], module)
    validated = validator.model_validate(config)
    return validated.model_dump()

376 

377 

def is_custom_definition(config, module):
    # type: (Getter, str) -> bool
    '''
    Determine whether config is of custom defined code.

    A config counts as custom if the given module holds a function or a class
    with the config's name.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        bool: True if config is of custom defined code.
    '''
    enforce_getter(config)
    # --------------------------------------------------------------------------

    name = config['name']
    # functions are checked first, then classes
    for lookup in (get_module_function, get_module_class):
        try:
            lookup(name, module)
        except NotImplementedError:
            continue
        return True
    return False