Coverage for /home/ubuntu/flatiron/python/flatiron/torch/tools.py: 98% of 148 statements (coverage.py v7.8.0, created at 2025-04-08 21:55 +0000)

from typing import Any, Callable, Optional  # noqa: F401
from flatiron.core.dataset import Dataset  # noqa: F401
from flatiron.core.types import Compiled, Filepath, Getter  # noqa: F401

from copy import deepcopy
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter
import lunchbox.tools as lbt
import numpy as np
import pandas as pd
import safetensors.torch as safetensors
import torch
import torch.utils.data as torchdata
import tqdm.notebook as tqdm

import flatiron.core.tools as fict
import flatiron.torch.loss as fi_torchloss
import flatiron.torch.metric as fi_torchmetric
import flatiron.torch.optimizer as fi_torchoptim
# ------------------------------------------------------------------------------


def resolve_config(config):
    # type: (dict) -> dict
    '''
    Resolve configs handed to Torch classes. Replaces the following keys
    with their Torch equivalents:

    * learning_rate
    * epsilon
    * clipping_threshold
    * exponent
    * norm_degree
    * beta_1
    * beta_2

    Args:
        config (dict): Config dict.

    Returns:
        dict: Resolved config.
    '''
    params = config.pop('params', None)
    output = deepcopy(config)
    if params is not None:
        output['params'] = params

    lut = dict(
        learning_rate='lr',
        epsilon='eps',
        clipping_threshold='d',
        exponent='p',
        norm_degree='p',
    )
    for key, val in config.items():
        if key in lut:
            output[lut[key]] = val
            del output[key]

    if 'beta_1' in output or 'beta_2' in output:
        beta_1 = output.pop('beta_1', 0.9)
        beta_2 = output.pop('beta_2', 0.999)
        output['betas'] = (beta_1, beta_2)

    return output
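
# Illustrative example: a Keras-style config is renamed to Torch argument
# names, and beta_1/beta_2 are merged into a betas tuple:
#
#     >>> resolve_config(dict(learning_rate=0.001, beta_1=0.9, beta_2=0.999))
#     {'lr': 0.001, 'betas': (0.9, 0.999)}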


def get(config, module, fallback_module):
    # type: (Getter, str, str) -> Any
    '''
    Given a config and a set of modules, return an instance or function.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.
        fallback_module (str): Fallback module, either a tf or torch module.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        object: Instance or function.
    '''
    fict.enforce_getter(config)
    # --------------------------------------------------------------------------

    config = resolve_config(config)
    name = config.pop('name')
    try:
        return fict.get_module_class(name, module)
    except NotImplementedError:
        mod = fict.get_module(fallback_module)
        return getattr(mod, name)(**config)
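
# Illustrative example: get(dict(name='CrossEntropyLoss'), __name__, 'torch.nn')
# would fall back to torch.nn and return CrossEntropyLoss(), assuming no class
# of that name exists in this module.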


# CALLBACKS---------------------------------------------------------------------
class ModelCheckpoint:
    '''
    Class for saving PyTorch models.
    '''
    def __init__(self, filepath, save_freq='epoch', **kwargs):
        # type: (Filepath, str, Any) -> None
        '''
        Constructs ModelCheckpoint instance.

        Args:
            filepath (str or Path): Filepath pattern.
            save_freq (str, optional): Save frequency. Default: epoch.
        '''
        self._filepath = Path(filepath).as_posix()
        self.save_freq = save_freq

    def save(self, model, epoch):
        # type: (torch.nn.Module, int) -> None
        '''
        Save PyTorch model.

        Args:
            model (torch.nn.Module): Model to be saved.
            epoch (int): Current epoch.
        '''
        filepath = self._filepath.format(epoch=epoch)
        if Path(filepath).suffix == '.safetensors':
            safetensors.save_model(model, filepath)
        else:
            torch.save(model, filepath)
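
# Illustrative example: the '{epoch}' placeholder in the pattern is filled in
# at save time, and the file extension selects the serialization backend:
#
#     >>> ckpt = ModelCheckpoint('/tmp/model_{epoch:03d}.safetensors')
#     >>> ckpt.save(model, 7)  # writes /tmp/model_007.safetensors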


Callbacks = dict[str, SummaryWriter | ModelCheckpoint]


def get_callbacks(log_directory, checkpoint_pattern, checkpoint_params={}):
    # type: (Filepath, str, dict) -> Callbacks
    '''
    Create a dict of callbacks for a PyTorch model.

    Args:
        log_directory (str or Path): Tensorboard project log directory.
        checkpoint_pattern (str): Filepath pattern for checkpoint callback.
        checkpoint_params (dict, optional): Params to be passed to checkpoint
            callback. Default: {}.

    Raises:
        EnforceError: If log directory does not exist.
        EnforceError: If checkpoint pattern does not contain '{epoch}'.

    Returns:
        dict: Tensorboard and ModelCheckpoint callbacks.
    '''
    fict.enforce_callbacks(log_directory, checkpoint_pattern)
    return dict(
        tensorboard=SummaryWriter(log_dir=log_directory),
        checkpoint=ModelCheckpoint(checkpoint_pattern, **checkpoint_params),
    )
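
# Illustrative example: the log directory must exist and the pattern must
# contain '{epoch}':
#
#     >>> cb = get_callbacks('/tmp/logs', '/tmp/logs/model_{epoch}.pth')
#     >>> cb['tensorboard']  # SummaryWriter instance
#     >>> cb['checkpoint']   # ModelCheckpoint instance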


# DATASET-----------------------------------------------------------------------
class TorchDataset(Dataset, torchdata.Dataset):
    '''
    Class for inheriting torch Dataset into flatiron Dataset.
    '''
    @staticmethod
    def monkey_patch(dataset, channels_first=True):
        # type: (Dataset, bool) -> TorchDataset
        '''
        Construct and monkey patch a new TorchDataset instance from a given
        Dataset. PyTorch expects images with the shape (C, H, W) by default.

        Args:
            dataset (Dataset): Dataset.
            channels_first (bool, optional): Will convert any matrix of shape
                (H, W, C) into (C, H, W). Default: True.

        Returns:
            TorchDataset: TorchDataset instance.
        '''
        this = TorchDataset(dataset.info)
        this._info = dataset._info.copy()
        this._info['frame'] = this._info.index
        this.data = dataset.data
        this.labels = dataset.labels
        this.label_axis = dataset.label_axis
        this._ext_regex = dataset._ext_regex
        this._calc_file_size = dataset._calc_file_size
        this._sample_gb = dataset._sample_gb
        this._channels_first = channels_first  # type: ignore
        return this

    def __getitem__(self, frame):
        # type: (int) -> list[torch.Tensor]
        '''
        Get tensor data by frame.

        Args:
            frame (int): Frame index.

        Returns:
            list[torch.Tensor]: List of Tensors.
        '''
        items = self.get_arrays(frame)

        # pytorch warns about arrays not being writable, this fixes that
        items = [x.copy() for x in items]

        # pytorch expects (C, H, W), so transpose (H, W, C) arrays
        if self._channels_first:  # type: ignore
            arrays = items
            items = []
            for item in arrays:
                if item.ndim == 3:
                    item = np.transpose(item, (2, 0, 1))
                items.append(item)

        output = list(map(torch.from_numpy, items))
        return output
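
# Illustrative example: with channels_first enabled, an RGB image of shape
# (256, 256, 3) becomes a tensor of shape (3, 256, 256):
#
#     >>> np.transpose(np.zeros((256, 256, 3)), (2, 0, 1)).shape
#     (3, 256, 256)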


# COMPILE-----------------------------------------------------------------------
def pre_build(device):
    '''
    No-op. Presumably present for API parity with other framework backends.
    '''
    pass


def compile(
    framework,  # type: Getter
    model,  # type: Any
    optimizer,  # type: Getter
    loss,  # type: Getter
    metrics,  # type: list[Getter]
):
    # type: (...) -> Compiled
    '''
    Call `torch.compile` on given model with kwargs.

    Args:
        framework (dict): Framework dict.
        model (Any): Model to be compiled.
        optimizer (dict): Optimizer config for compilation.
        loss (dict): Loss config for compilation.
        metrics (list[dict]): Metric configs for compilation.

    Returns:
        dict: Dict of compiled objects.
    '''
    kwargs = dict(filter(
        lambda x: x[0] not in ['name', 'device'], framework.items()
    ))
    return dict(
        framework=framework,
        model=torch.compile(model, **kwargs),
        optimizer=fi_torchoptim.get(optimizer, model),
        loss=fi_torchloss.get(loss),
        metrics=[fi_torchmetric.get(m) for m in metrics],
    )
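
# Illustrative example: every framework key except name and device is
# forwarded to torch.compile as a keyword argument:
#
#     >>> framework = dict(name='torch', device='cuda', mode='reduce-overhead')
#     >>> compiled = compile(
#     ...     framework, model,
#     ...     optimizer=dict(name='Adam', learning_rate=0.001),
#     ...     loss=dict(name='MSELoss'),
#     ...     metrics=[],
#     ... )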


# TRAIN-------------------------------------------------------------------------
def _execute_epoch(
    epoch,  # type: int
    model,  # type: torch.nn.Module
    data_loader,  # type: torch.utils.data.DataLoader
    optimizer,  # type: torch.optim.Optimizer
    loss_func,  # type: torch.nn.Module
    device,  # type: torch.device
    metrics_funcs=[],  # type: list[Callable]
    writer=None,  # type: Optional[SummaryWriter]
    checkpoint=None,  # type: Optional[ModelCheckpoint]
    mode='train',  # type: str
):
    # type: (...) -> None
    '''
    Execute train or test epoch on given torch model.

    Args:
        epoch (int): Current epoch.
        model (torch.nn.Module): Torch model.
        data_loader (torch.utils.data.DataLoader): Torch data loader.
        optimizer (torch.optim.Optimizer): Torch optimizer.
        loss_func (torch.nn.Module): Torch loss function.
        device (torch.device): Torch device.
        metrics_funcs (list[Callable], optional): List of torch metrics.
            Default: [].
        writer (SummaryWriter, optional): Tensorboard writer. Default: None.
        checkpoint (ModelCheckpoint, optional): Model saver. Default: None.
        mode (str, optional): Mode to execute. Options: [train, test].
            Default: train.

    Raises:
        ValueError: If mode is not train or test.
    '''
    if mode == 'train':
        context = torch.enable_grad  # type: Any
        model.train()
    elif mode == 'test':
        context = torch.inference_mode
        model.eval()
    else:
        raise ValueError(f'Invalid mode: {mode}.')

    # checkpoint mode
    checkpoint_mode = checkpoint is not None and checkpoint.save_freq == 'batch'

    metrics = []
    epoch_size = len(data_loader)
    with context():
        for i, batch in enumerate(data_loader):
            # get x and y
            if len(batch) == 2:
                x, y = batch
                x = x.to(device)
                y = y.to(device)
            else:
                x = batch
                x = x.to(device)
                y = x

            y_pred = model(x)
            loss = loss_func(y_pred, y)

            # train model on batch
            if mode == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # gather batch metrics
            batch_metrics = dict(loss=loss)
            for metric in metrics_funcs:
                key = lbt.to_snakecase(metric.__class__.__name__)
                batch_metrics[key] = metric(y_pred, y)
            metrics.append(batch_metrics)

            # write batch metrics
            if writer is not None:
                batch_index = epoch * epoch_size + i
                for key, val in batch_metrics.items():
                    writer.add_scalar(f'batch_{mode}_{key}', val, batch_index)

            # save model
            if checkpoint_mode:
                checkpoint.save(model, epoch)  # type: ignore

    # write mean epoch metrics
    if writer is not None:
        epoch_metrics = pd.DataFrame(metrics) \
            .map(lambda x: x.cpu().detach().numpy().mean()) \
            .rename(lambda x: f'epoch_{mode}_{x}', axis=1) \
            .mean() \
            .to_dict()

        for key, val in epoch_metrics.items():
            writer.add_scalar(key, val, epoch)
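
# Illustrative example: metric keys are derived from class names via
# lbt.to_snakecase, so an instance of a MeanSquaredError class would be logged
# as 'batch_train_mean_squared_error' and 'epoch_train_mean_squared_error'.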


def train(
    compiled,  # type: Compiled
    callbacks,  # type: Callbacks
    train_data,  # type: Dataset
    test_data,  # type: Dataset
    params,  # type: dict
):
    # type: (...) -> None
    '''
    Train Torch model.

    Args:
        compiled (dict): Compiled objects.
        callbacks (dict): Dict of callbacks.
        train_data (Dataset): Training dataset.
        test_data (Dataset): Test dataset.
        params (dict): Training params.
    '''
    framework = compiled['framework']
    model = compiled['model']
    optimizer = compiled['optimizer']
    loss = compiled['loss']
    metrics = compiled['metrics']
    checkpoint = callbacks['checkpoint']  # type: Any
    writer = callbacks['tensorboard']
    batch_size = params['batch_size']

    device = torch.device(framework['device'])
    torch.manual_seed(params['seed'])
    model = model.to(device)
    loss = loss.to(device)
    metrics = [x.to(device) for x in metrics]

    train_loader = torchdata.DataLoader(
        TorchDataset.monkey_patch(train_data), batch_size=batch_size
    )  # type: torchdata.DataLoader
    test_loader = torchdata.DataLoader(
        TorchDataset.monkey_patch(test_data), batch_size=batch_size
    )  # type: torchdata.DataLoader

    kwargs = dict(
        model=model,
        optimizer=optimizer,
        loss_func=loss,
        device=device,
        metrics_funcs=metrics,
        writer=writer,
    )
    for i in tqdm.trange(params['epochs']):
        _execute_epoch(
            epoch=i, mode='train', data_loader=train_loader,
            checkpoint=checkpoint, **kwargs
        )
        _execute_epoch(epoch=i, mode='test', data_loader=test_loader, **kwargs)
        if checkpoint.save_freq == 'epoch':
            checkpoint.save(model, i)
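
# Illustrative example: train reads batch_size, seed, and epochs from params:
#
#     >>> params = dict(batch_size=32, seed=42, epochs=10)
#     >>> train(compiled, callbacks, train_data, test_data, params)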