Coverage for /home/ubuntu/flatiron/python/flatiron/torch/tools.py: 98%

148 statements  


from typing import Any, Callable, Optional  # noqa: F401
from flatiron.core.dataset import Dataset  # noqa: F401
from flatiron.core.types import Compiled, Filepath, Getter  # noqa: F401

from copy import deepcopy
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter
import lunchbox.tools as lbt
import numpy as np
import pandas as pd
import safetensors.torch as safetensors
import torch
import torch.utils.data as torchdata
import tqdm.notebook as tqdm

import flatiron.core.tools as fict
import flatiron.torch.loss as fi_torchloss
import flatiron.torch.metric as fi_torchmetric
import flatiron.torch.optimizer as fi_torchoptim
# ------------------------------------------------------------------------------

def resolve_config(config):
    # type: (dict) -> dict
    '''
    Resolve configs handed to Torch classes. Renames or merges the following
    keys into their Torch argument equivalents:

        * learning_rate
        * epsilon
        * clipping_threshold
        * exponent
        * norm_degree
        * beta_1
        * beta_2

    Args:
        config (dict): Config dict.

    Returns:
        dict: Resolved config.
    '''
    params = config.pop('params', None)
    output = deepcopy(config)
    if params is not None:
        output['params'] = params

    lut = dict(
        learning_rate='lr',
        epsilon='eps',
        clipping_threshold='d',
        exponent='p',
        norm_degree='p',
    )
    for key, val in config.items():
        if key in lut:
            output[lut[key]] = val
            del output[key]

    if 'beta_1' in output or 'beta_2' in output:
        beta_1 = output.pop('beta_1', 0.9)
        beta_2 = output.pop('beta_2', 0.999)
        output['betas'] = (beta_1, beta_2)

    return output
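# A minimal sketch of the key renaming (values are illustrative, not from the
# source):
#
# >>> resolve_config(dict(name='Adam', learning_rate=0.001, beta_1=0.9))
# {'name': 'Adam', 'lr': 0.001, 'betas': (0.9, 0.999)}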

def get(config, module, fallback_module):
    # type: (Getter, str, str) -> Any
    '''
    Given a config and a set of modules, return an instance or function.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.
        fallback_module (str): Fallback module, either a tf or torch module.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        object: Instance or function.
    '''
    fict.enforce_getter(config)
    # --------------------------------------------------------------------------

    config = resolve_config(config)
    name = config.pop('name')
    try:
        return fict.get_module_class(name, module)
    except NotImplementedError:
        mod = fict.get_module(fallback_module)
        return getattr(mod, name)(**config)
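# Usage sketch (config is illustrative): a name not found in the given module
# falls back to the fallback module and is instantiated with the resolved
# config:
#
# >>> get(dict(name='CrossEntropyLoss'), __name__, 'torch.nn')
# CrossEntropyLoss()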

# CALLBACKS---------------------------------------------------------------------
class ModelCheckpoint:
    '''
    Class for saving PyTorch models.
    '''
    def __init__(self, filepath, save_freq='epoch', **kwargs):
        # type: (Filepath, str, Any) -> None
        '''
        Constructs ModelCheckpoint instance.

        Args:
            filepath (str or Path): Filepath pattern.
            save_freq (str, optional): Save frequency. Options: [epoch, batch].
                Default: epoch.
        '''
        self._filepath = Path(filepath).as_posix()
        self.save_freq = save_freq

    def save(self, model, epoch):
        # type: (torch.nn.Module, int) -> None
        '''
        Save PyTorch model.

        Args:
            model (torch.nn.Module): Model to be saved.
            epoch (int): Current epoch.
        '''
        filepath = self._filepath.format(epoch=epoch)
        if Path(filepath).suffix == '.safetensors':
            safetensors.save_model(model, filepath)
        else:
            torch.save(model, filepath)
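# Usage sketch (path is illustrative): '{epoch}' is interpolated at save time
# and the file extension selects the serializer:
#
# >>> ckpt = ModelCheckpoint('/tmp/model_{epoch}.safetensors')
# >>> ckpt.save(model, 7)  # writes /tmp/model_7.safetensors via safetensors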

Callbacks = dict[str, SummaryWriter | ModelCheckpoint]

def get_callbacks(log_directory, checkpoint_pattern, checkpoint_params={}):
    # type: (Filepath, str, dict) -> Callbacks
    '''
    Create callbacks for a Torch model.

    Args:
        log_directory (str or Path): Tensorboard project log directory.
        checkpoint_pattern (str): Filepath pattern for checkpoint callback.
        checkpoint_params (dict, optional): Params to be passed to checkpoint
            callback. Default: {}.

    Raises:
        EnforceError: If log directory does not exist.
        EnforceError: If checkpoint pattern does not contain '{epoch}'.

    Returns:
        dict: Tensorboard and ModelCheckpoint callbacks.
    '''
    fict.enforce_callbacks(log_directory, checkpoint_pattern)
    return dict(
        tensorboard=SummaryWriter(log_dir=log_directory),
        checkpoint=ModelCheckpoint(checkpoint_pattern, **checkpoint_params),
    )
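# Usage sketch (paths are illustrative); the log directory must exist and the
# pattern must contain '{epoch}':
#
# >>> callbacks = get_callbacks('/tmp/logs', '/tmp/model_{epoch}.safetensors')
# >>> sorted(callbacks)
# ['checkpoint', 'tensorboard']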

# DATASET-----------------------------------------------------------------------
class TorchDataset(Dataset, torchdata.Dataset):
    '''
    Class for inheriting torch Dataset into flatiron Dataset.
    '''
    @staticmethod
    def monkey_patch(dataset, channels_first=True):
        # type: (Dataset, bool) -> TorchDataset
        '''
        Construct and monkey patch a new TorchDataset instance from a given
        Dataset. PyTorch expects images with the shape (C, H, W) by default.

        Args:
            dataset (Dataset): Dataset.
            channels_first (bool, optional): Will convert any matrix of shape
                (H, W, C) into (C, H, W). Default: True.

        Returns:
            TorchDataset: TorchDataset instance.
        '''
        this = TorchDataset(dataset.info)
        this._info = dataset._info.copy()
        this._info['frame'] = this._info.index
        this.data = dataset.data
        this.labels = dataset.labels
        this.label_axis = dataset.label_axis
        this._ext_regex = dataset._ext_regex
        this._calc_file_size = dataset._calc_file_size
        this._sample_gb = dataset._sample_gb
        this._channels_first = channels_first  # type: ignore
        return this

    def __getitem__(self, frame):
        # type: (int) -> list[torch.Tensor]
        '''
        Get tensor data by frame.

        Args:
            frame (int): Frame index.

        Returns:
            list[torch.Tensor]: List of Tensors.
        '''
        items = self.get_arrays(frame)

        # pytorch warns about arrays not being writable, this fixes that
        items = [x.copy() for x in items]

        # convert (H, W, C) arrays to pytorch's expected (C, H, W) layout
        if self._channels_first:  # type: ignore
            arrays = items
            items = []
            for item in arrays:
                if item.ndim == 3:
                    item = np.transpose(item, (2, 0, 1))
                items.append(item)

        output = list(map(torch.from_numpy, items))
        return output
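# Usage sketch (assumes `dataset` is an existing flatiron Dataset): wrapping
# makes it indexable by torch's DataLoader, with (H, W, C) images transposed
# to (C, H, W):
#
# >>> torch_dataset = TorchDataset.monkey_patch(dataset)
# >>> x, y = torch_dataset[0]  # pair of torch.Tensors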

# COMPILE-----------------------------------------------------------------------
def pre_build(device):
    # no-op placeholder, presumably kept for API parity with the other
    # framework backends; Torch needs no pre-build device setup here
    pass

def compile(
    framework,  # type: Getter
    model,  # type: Any
    optimizer,  # type: Getter
    loss,  # type: Getter
    metrics,  # type: list[Getter]
):
    # type: (...) -> Compiled
    '''
    Call `torch.compile` on given model with kwargs.

    Args:
        framework (dict): Framework dict.
        model (Any): Model to be compiled.
        optimizer (dict): Optimizer config.
        loss (dict): Loss config.
        metrics (list[dict]): List of metric configs.

    Returns:
        dict: Dict of compiled objects.
    '''
    kwargs = dict(filter(
        lambda x: x[0] not in ['name', 'device'], framework.items()
    ))
    return dict(
        framework=framework,
        model=torch.compile(model, **kwargs),
        optimizer=fi_torchoptim.get(optimizer, model),
        loss=fi_torchloss.get(loss),
        metrics=[fi_torchmetric.get(m) for m in metrics],
    )
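# Usage sketch (configs are illustrative): every framework key except `name`
# and `device` is forwarded to torch.compile:
#
# >>> compiled = compile(
# ...     framework=dict(name='torch', device='cuda', mode='default'),
# ...     model=model,
# ...     optimizer=dict(name='SGD', learning_rate=0.01),
# ...     loss=dict(name='CrossEntropyLoss'),
# ...     metrics=[],
# ... )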

# TRAIN-------------------------------------------------------------------------
def _execute_epoch(
    epoch,  # type: int
    model,  # type: torch.nn.Module
    data_loader,  # type: torch.utils.data.DataLoader
    optimizer,  # type: torch.optim.Optimizer
    loss_func,  # type: torch.nn.Module
    device,  # type: torch.device
    metrics_funcs=[],  # type: list[Callable]
    writer=None,  # type: Optional[SummaryWriter]
    checkpoint=None,  # type: Optional[ModelCheckpoint]
    mode='train',  # type: str
):
    # type: (...) -> None
    '''
    Execute train or test epoch on given torch model.

    Args:
        epoch (int): Current epoch.
        model (torch.nn.Module): Torch model.
        data_loader (torch.utils.data.DataLoader): Torch data loader.
        optimizer (torch.optim.Optimizer): Torch optimizer.
        loss_func (torch.nn.Module): Torch loss function.
        device (torch.device): Torch device.
        metrics_funcs (list[Callable], optional): List of torch metrics.
            Default: [].
        writer (SummaryWriter, optional): Tensorboard writer. Default: None.
        checkpoint (ModelCheckpoint, optional): Model saver. Default: None.
        mode (str, optional): Mode to execute. Options: [train, test].
            Default: train.

    Raises:
        ValueError: If mode is not train or test.
    '''
    if mode == 'train':
        context = torch.enable_grad  # type: Any
        model.train()
    elif mode == 'test':
        context = torch.inference_mode
        model.eval()
    else:
        raise ValueError(f'Invalid mode: {mode}.')

    # save per batch rather than per epoch if requested
    checkpoint_mode = checkpoint is not None and checkpoint.save_freq == 'batch'

    metrics = []
    epoch_size = len(data_loader)
    with context():
        for i, batch in enumerate(data_loader):
            # get x and y
            if len(batch) == 2:
                x, y = batch
                x = x.to(device)
                y = y.to(device)
            else:
                x = batch
                x = x.to(device)
                y = x

            y_pred = model(x)
            loss = loss_func(y_pred, y)

            # train model on batch
            if mode == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # gather batch metrics
            batch_metrics = dict(loss=loss)
            for metric in metrics_funcs:
                key = lbt.to_snakecase(metric.__class__.__name__)
                batch_metrics[key] = metric(y_pred, y)
            metrics.append(batch_metrics)

            # write batch metrics
            if writer is not None:
                batch_index = epoch * epoch_size + i
                for key, val in batch_metrics.items():
                    writer.add_scalar(f'batch_{mode}_{key}', val, batch_index)

            # save model
            if checkpoint_mode:
                checkpoint.save(model, epoch)  # type: ignore

    # write mean epoch metrics
    if writer is not None:
        epoch_metrics = pd.DataFrame(metrics) \
            .map(lambda x: x.cpu().detach().numpy().mean()) \
            .rename(lambda x: f'epoch_{mode}_{x}', axis=1) \
            .mean() \
            .to_dict()

        for key, val in epoch_metrics.items():
            writer.add_scalar(key, val, epoch)
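# Note: scalars are tagged 'batch_{mode}_{key}' per batch and
# 'epoch_{mode}_{key}' per epoch, where key is the snake_case class name of
# each metric (loss is always included).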

def train(
    compiled,  # type: Compiled
    callbacks,  # type: Callbacks
    train_data,  # type: Dataset
    test_data,  # type: Dataset
    params,  # type: dict
):
    # type: (...) -> None
    '''
    Train Torch model.

    Args:
        compiled (dict): Compiled objects.
        callbacks (dict): Dict of callbacks.
        train_data (Dataset): Training dataset.
        test_data (Dataset): Test dataset.
        params (dict): Training params. Required keys: batch_size, seed,
            epochs.
    '''
    framework = compiled['framework']
    model = compiled['model']
    optimizer = compiled['optimizer']
    loss = compiled['loss']
    metrics = compiled['metrics']
    checkpoint = callbacks['checkpoint']  # type: Any
    writer = callbacks['tensorboard']
    batch_size = params['batch_size']

    device = torch.device(framework['device'])
    torch.manual_seed(params['seed'])
    model = model.to(device)
    loss = loss.to(device)
    metrics = [x.to(device) for x in metrics]

    train_loader = torchdata.DataLoader(
        TorchDataset.monkey_patch(train_data), batch_size=batch_size
    )  # type: torchdata.DataLoader
    test_loader = torchdata.DataLoader(
        TorchDataset.monkey_patch(test_data), batch_size=batch_size
    )  # type: torchdata.DataLoader

    kwargs = dict(
        model=model,
        optimizer=optimizer,
        loss_func=loss,
        device=device,
        metrics_funcs=metrics,
        writer=writer,
    )
    for i in tqdm.trange(params['epochs']):
        _execute_epoch(
            epoch=i, mode='train', data_loader=train_loader,
            checkpoint=checkpoint, **kwargs
        )
        _execute_epoch(epoch=i, mode='test', data_loader=test_loader, **kwargs)
        if checkpoint.save_freq == 'epoch':
            checkpoint.save(model, i)
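# End-to-end sketch (assumes model, datasets, and configs defined elsewhere):
#
# >>> compiled = compile(framework, model, optimizer, loss, metrics)
# >>> callbacks = get_callbacks('/tmp/logs', '/tmp/model_{epoch}.safetensors')
# >>> train(
# ...     compiled, callbacks, train_data, test_data,
# ...     params=dict(batch_size=32, seed=42, epochs=10),
# ... )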