# flatiron/torch/config.py

from typing import Optional, Union
from typing_extensions import Annotated
from flatiron.core.types import OptBool, Ints, OptInt, OptInts, OptStr
from flatiron.core.types import Floats, OptFloat, OptListFloat, OptPairFloat

import pydantic as pyd
# ------------------------------------------------------------------------------

_DEVICE_RE = '^(cpu|cuda|ipu|xpu|mkldnn|opengl|opencl|ideep|hip|ve|fpga|maia'
_DEVICE_RE += '|xla|lazy|vulkan|mps|meta|hpu|mtia|privateuseone)$'
_TOKENIZE_RE = '^(none|13a|zh|intl|char|ja-mecab|ko-mecab|flores101|flores200)$'
ReductionType = Annotated[str, pyd.Field(pattern='^(mean|sum|none)$')]


# BASE--------------------------------------------------------------------------
class TorchBaseConfig(pyd.BaseModel):
    name: str


# FRAMEWORK---------------------------------------------------------------------
class TorchFramework(pyd.BaseModel):
    '''
    Configuration for calls to the torch train function.

    Attributes:
        name (str): Framework name. Default: 'torch'.
        device (str, optional): Hardware device. Default: 'cuda'.
    '''
    name: str = 'torch'
    device: Annotated[str, pyd.Field(pattern=_DEVICE_RE)] = 'cuda'
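

# Illustrative usage sketch (not part of the original module): TorchFramework
# validates 'device' against _DEVICE_RE, so an unsupported device string fails
# at config-validation time rather than at train time. Assumes pydantic v2
# (model_validate); the dict values are demonstration-only assumptions.
def _example_framework_config() -> TorchFramework:
    # any string matching _DEVICE_RE is accepted; anything else raises
    # pydantic.ValidationError
    return TorchFramework.model_validate(dict(name='torch', device='cpu'))
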

# OPTIMIZER-HELPERS-------------------------------------------------------------
class TorchOptBaseConfig(TorchBaseConfig):
    learning_rate: float = 0.01  # convert to lr


class TMax(pyd.BaseModel):
    maximize: bool = False


class TFor(pyd.BaseModel):
    foreach: OptBool = None


class TDiff(pyd.BaseModel):
    differentiable: bool = False


class TEps(pyd.BaseModel):
    epsilon: float = 1e-06  # convert to eps


class TCap(pyd.BaseModel):
    capturable: bool = False


class TDecay(pyd.BaseModel):
    weight_decay: float = 0


class TBeta(pyd.BaseModel):
    beta_1: float = 0.9
    beta_2: float = 0.999  # convert to betas: tuple[float, float]


class TGroup1(TCap, TDecay, TDiff, TEps, TFor, TMax):
    pass
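

# Illustrative sketch (an assumption, not part of the original module): the
# "# convert to ..." comments above suggest some fields are renamed before
# being handed to a torch.optim constructor (learning_rate -> lr,
# epsilon -> eps, beta_1/beta_2 -> betas). One possible conversion helper:
def _example_to_torch_kwargs(config: TorchOptBaseConfig) -> dict:
    # drop the config name and rename fields to the torch.optim spellings
    kwargs = config.model_dump(exclude={'name'})
    if 'learning_rate' in kwargs:
        kwargs['lr'] = kwargs.pop('learning_rate')
    if 'epsilon' in kwargs:
        kwargs['eps'] = kwargs.pop('epsilon')
    if 'beta_1' in kwargs and 'beta_2' in kwargs:
        kwargs['betas'] = (kwargs.pop('beta_1'), kwargs.pop('beta_2'))
    return kwargs
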

# LOSS-HELPERS------------------------------------------------------------------
class TReduct(pyd.BaseModel):
    reduction: ReductionType = 'mean'


class TRed(pyd.BaseModel):
    reduce: OptBool = None


class TSize(pyd.BaseModel):
    size_average: OptBool = None


class TMarg(pyd.BaseModel):
    margin: float = 0.0


class TGroup2(TRed, TReduct, TSize):
    pass


class TGroup3(TMarg, TRed, TReduct, TSize):
    pass


# METRIC-HELPERS----------------------------------------------------------------
class TInd(pyd.BaseModel):
    ignore_index: OptInt = None


class TNan(pyd.BaseModel):
    nan_strategy: Union[
        float, Annotated[str, pyd.Field(pattern='^(error|warn|ignore|disable)$')]
    ] = 'warn'


class TAct(pyd.BaseModel):
    empty_target_action: Annotated[
        str, pyd.Field(pattern='^(error|skip|neg|pos)$')
    ] = 'neg'


class TOut(pyd.BaseModel):
    num_outputs: int = 1


class TMReduct(pyd.BaseModel):
    reduction: Annotated[
        str, pyd.Field(pattern='^(elementwise_mean|sum|none)$')
    ] = 'elementwise_mean'


class TTopK(pyd.BaseModel):
    top_k: OptInt = None


class TCls(pyd.BaseModel):
    num_classes: OptInt = None  # has multiple signatures


class TDate(pyd.BaseModel):
    data_range: OptPairFloat = None


class TNanStrategy(pyd.BaseModel):
    nan_strategy: Annotated[str, pyd.Field(pattern='^(replace|drop)$')] = 'replace'


# OPTIMIZER---------------------------------------------------------------------
class TorchOptASGD(TorchOptBaseConfig, TCap, TDecay, TDiff, TFor, TMax):
    alpha: float = 0.75
    lambd: float = 0.0001
    t0: float = 1000000.0


class TorchOptAdadelta(TorchOptBaseConfig, TGroup1):
    rho: float = 0.9


class TorchOptAdafactor(TorchOptBaseConfig, TDecay, TEps, TFor, TMax):
    beta2_decay: float = -0.8
    clipping_threshold: float = 1.0  # convert to d


class TorchOptAdagrad(TorchOptBaseConfig, TDecay, TDiff, TEps, TFor, TMax):
    fused: OptBool = None
    initial_accumulator_value: float = 0
    lr_decay: float = 0


class TorchOptAdam(TorchOptBaseConfig, TGroup1, TBeta):
    amsgrad: bool = False
    fused: OptBool = None


class TorchOptAdamW(TorchOptBaseConfig, TGroup1, TBeta):
    amsgrad: bool = False
    fused: OptBool = None


class TorchOptAdamax(TorchOptBaseConfig, TGroup1, TBeta):
    pass


class TorchOptLBFGS(TorchOptBaseConfig):
    history_size: int = 100
    line_search_fn: OptStr = None
    max_eval: OptInt = None
    max_iter: int = 20
    tolerance_change: float = 1e-09
    tolerance_grad: float = 1e-07


class TorchOptNAdam(TorchOptBaseConfig, TGroup1, TBeta):
    momentum_decay: float = 0.004


class TorchOptRAdam(TorchOptBaseConfig, TGroup1, TBeta):
    pass


class TorchOptRMSprop(TorchOptBaseConfig, TGroup1):
    alpha: float = 0.99
    centered: bool = False
    momentum: float = 0


class TorchOptRprop(TorchOptBaseConfig, TCap, TDiff, TFor, TMax):
    etas: tuple[float, float] = (0.5, 1.2)
    step_sizes: tuple[float, float] = (1e-06, 50)


class TorchOptSGD(TorchOptBaseConfig, TDecay, TDiff, TFor, TMax):
    dampening: float = 0
    fused: OptBool = None
    momentum: float = 0
    nesterov: bool = False


class TorchOptSparseAdam(TorchOptBaseConfig, TEps, TMax, TBeta):
    pass
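

# Illustrative sketch (not part of the original module): optimizer configs
# inherit most fields from the mixins above, so a minimal config only needs a
# name. The name value below is an assumption for demonstration only.
def _example_adam_config() -> TorchOptAdam:
    # unset fields fall back to mixin defaults (learning_rate=0.01,
    # epsilon=1e-06, beta_1=0.9, beta_2=0.999, weight_decay=0, ...)
    return TorchOptAdam(name='Adam')
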

# LOSS--------------------------------------------------------------------------
class TorchLossBCELoss(TorchBaseConfig, TGroup2):
    pass


class TorchLossBCEWithLogitsLoss(TorchBaseConfig, TGroup2):
    pass


class TorchLossCTCLoss(TorchBaseConfig, TReduct):
    blank: int = 0
    zero_infinity: bool = False


class TorchLossCosineEmbeddingLoss(TorchBaseConfig, TGroup3):
    pass


class TorchLossCrossEntropyLoss(TorchBaseConfig, TGroup2):
    ignore_index: int = -100
    label_smoothing: float = 0.0


class TorchLossGaussianNLLLoss(TorchBaseConfig, TEps, TReduct):
    full: bool = False


class TorchLossHingeEmbeddingLoss(TorchBaseConfig, TGroup3):
    pass


class TorchLossHuberLoss(TorchBaseConfig, TReduct):
    delta: float = 1.0


class TorchLossKLDivLoss(TorchBaseConfig, TGroup2):
    log_target: bool = False


class TorchLossL1Loss(TorchBaseConfig, TGroup2):
    pass


class TorchLossMSELoss(TorchBaseConfig, TGroup2):
    pass


class TorchLossMarginRankingLoss(TorchBaseConfig, TGroup3):
    pass


class TorchLossMultiLabelMarginLoss(TorchBaseConfig, TGroup2):
    pass


class TorchLossMultiLabelSoftMarginLoss(TorchBaseConfig, TGroup2):
    pass


class TorchLossMultiMarginLoss(TorchBaseConfig, TGroup3):
    exponent: int = 1  # convert to p


class TorchLossNLLLoss(TorchBaseConfig, TGroup2):
    ignore_index: int = -100


class TorchLossPairwiseDistance(TorchBaseConfig, TEps):
    keepdim: bool = False
    norm_degree: float = 2.0  # convert to p


class TorchLossPoissonNLLLoss(TorchBaseConfig, TEps, TGroup2):
    full: bool = False
    log_input: bool = True


class TorchLossSmoothL1Loss(TorchBaseConfig, TGroup2):
    beta: float = 1.0


class TorchLossSoftMarginLoss(TorchBaseConfig, TGroup2):
    pass


class TorchLossTripletMarginLoss(TorchBaseConfig, TEps, TGroup3):
    norm_degree: float = 2.0  # convert to p
    swap: bool = False


class TorchLossTripletMarginWithDistanceLoss(TorchBaseConfig, TMarg, TReduct):
    swap: bool = False
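

# Illustrative sketch (not part of the original module): loss configs compose
# the TGroup2/TGroup3 mixins, so TorchLossCrossEntropyLoss carries reduction,
# reduce and size_average in addition to its own fields. The field values
# below are assumptions for demonstration only.
def _example_cross_entropy_config() -> TorchLossCrossEntropyLoss:
    return TorchLossCrossEntropyLoss(name='CrossEntropyLoss', label_smoothing=0.1)
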

# METRICS-----------------------------------------------------------------------
class TorchMetricBLEUScore(TorchBaseConfig):
    n_gram: int = 4
    smooth: bool = False
    weights: OptListFloat = None


class TorchMetricCHRFScore(TorchBaseConfig):
    beta: float = 2.0
    lowercase: bool = False
    n_char_order: int = 6
    n_word_order: int = 2
    return_sentence_level_score: bool = False
    whitespace: bool = False


class TorchMetricCatMetric(TorchBaseConfig, TNan):
    pass


class TorchMetricConcordanceCorrCoef(TorchBaseConfig, TOut):
    pass


class TorchMetricCosineSimilarity(TorchBaseConfig):
    reduction: ReductionType = 'sum'


class TorchMetricCramersV(TorchBaseConfig, TCls, TNanStrategy):
    bias_correction: bool = True
    nan_replace_value: OptFloat = 0.0


class TorchMetricCriticalSuccessIndex(TorchBaseConfig):
    keep_sequence_dim: OptInt = None
    threshold: float


class TorchMetricDice(TorchBaseConfig, TCls, TInd, TTopK):
    average: Optional[Annotated[
        str, pyd.Field(pattern='^(micro|macro|weighted|samples|none)$')
    ]] = 'micro'
    mdmc_average: Optional[Annotated[
        str, pyd.Field(pattern='^(samplewise|global)$')
    ]] = 'global'
    multiclass: OptBool = None
    threshold: float = 0.5
    zero_division: int = 0


class TorchMetricErrorRelativeGlobalDimensionlessSynthesis(TorchBaseConfig, TMReduct):
    ratio: float = 4


class TorchMetricExplainedVariance(TorchBaseConfig):
    multioutput: Annotated[
        str, pyd.Field(pattern='^(raw_values|uniform_average|variance_weighted)$')
    ] = 'uniform_average'


class TorchMetricExtendedEditDistance(TorchBaseConfig):
    alpha: float = 2.0
    deletion: float = 0.2
    insertion: float = 1.0
    language: Annotated[str, pyd.Field(pattern='^(en|ja)$')] = 'en'
    return_sentence_level_score: bool = False
    rho: float = 0.3


class TorchMetricFleissKappa(TorchBaseConfig):
    mode: Annotated[str, pyd.Field(pattern='^(counts|probs)$')] = 'counts'


class TorchMetricKLDivergence(TorchBaseConfig):
    log_prob: bool = False
    reduction: ReductionType = 'mean'


class TorchMetricKendallRankCorrCoef(TorchBaseConfig, TOut):
    alternative: Optional[Annotated[
        str, pyd.Field(pattern='^(two-sided|less|greater)$')
    ]] = 'two-sided'
    t_test: bool = False
    variant: Annotated[str, pyd.Field(pattern='^(a|b|c)$')] = 'b'


class TorchMetricLogCoshError(TorchBaseConfig, TOut):
    pass


class TorchMetricMaxMetric(TorchBaseConfig, TNan):
    pass


class TorchMetricMeanAbsoluteError(TorchBaseConfig, TOut):
    pass


class TorchMetricMeanMetric(TorchBaseConfig, TNan):
    pass


class TorchMetricMeanSquaredError(TorchBaseConfig, TOut):
    squared: bool = True


class TorchMetricMinMetric(TorchBaseConfig, TNan):
    pass


class TorchMetricMinkowskiDistance(TorchBaseConfig):
    p: float


class TorchMetricModifiedPanopticQuality(TorchBaseConfig):
    allow_unknown_preds_category: bool = False
    stuffs: list[int]
    things: list[int]


class TorchMetricMultiScaleStructuralSimilarityIndexMeasure(TorchBaseConfig, TMReduct, TDate):
    betas: tuple = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
    gaussian_kernel: bool = True
    k1: float = 0.01
    k2: float = 0.03
    kernel_size: Ints = 11
    normalize: Optional[Annotated[
        str, pyd.Field(pattern='^(relu|simple)$')
    ]] = 'relu'
    sigma: Floats = 1.5


class TorchMetricNormalizedRootMeanSquaredError(TorchBaseConfig, TOut):
    normalization: Annotated[
        str, pyd.Field(pattern='^(mean|range|std|l2)$')
    ] = 'mean'


class TorchMetricPanopticQuality(TorchBaseConfig):
    allow_unknown_preds_category: bool = False
    stuffs: list[int]
    things: list[int]


class TorchMetricPeakSignalNoiseRatio(TorchBaseConfig, TMReduct, TDate):
    base: float = 10.0
    dim: OptInts = None


class TorchMetricPearsonCorrCoef(TorchBaseConfig, TOut):
    pass


class TorchMetricPearsonsContingencyCoefficient(TorchBaseConfig, TNanStrategy):
    nan_replace_value: OptFloat = 0.0
    num_classes: int


class TorchMetricPermutationInvariantTraining(TorchBaseConfig):
    eval_func: Annotated[str, pyd.Field(pattern='^(max|min)$')] = 'max'
    mode: Annotated[
        str, pyd.Field(pattern='^(speaker-wise|permutation-wise)$')
    ] = 'speaker-wise'


class TorchMetricPerplexity(TorchBaseConfig, TInd):
    pass


class TorchMetricR2Score(TorchBaseConfig):
    adjusted: int = 0
    multioutput: Annotated[
        str, pyd.Field(pattern='^(raw_values|uniform_average|variance_weighted)$')
    ] = 'uniform_average'


class TorchMetricRelativeAverageSpectralError(TorchBaseConfig):
    window_size: int = 8


class TorchMetricRelativeSquaredError(TorchBaseConfig, TOut):
    squared: bool = True


class TorchMetricRetrievalFallOut(TorchBaseConfig, TInd, TTopK):
    empty_target_action: Annotated[
        str, pyd.Field(pattern='^(error|skip|neg|pos)$')
    ] = 'pos'


class TorchMetricRetrievalHitRate(TorchBaseConfig, TAct, TInd, TTopK):
    pass


class TorchMetricRetrievalMAP(TorchBaseConfig, TAct, TInd, TTopK):
    pass


class TorchMetricRetrievalMRR(TorchBaseConfig, TAct, TInd):
    pass


class TorchMetricRetrievalNormalizedDCG(TorchBaseConfig, TAct, TInd, TTopK):
    pass


class TorchMetricRetrievalPrecision(TorchBaseConfig, TAct, TInd, TTopK):
    adaptive_k: bool = False


class TorchMetricRetrievalPrecisionRecallCurve(TorchBaseConfig, TInd):
    adaptive_k: bool = False
    max_k: OptInt = None


class TorchMetricRetrievalRPrecision(TorchBaseConfig, TAct, TInd):
    pass


class TorchMetricRetrievalRecall(TorchBaseConfig, TAct, TInd, TTopK):
    pass


class TorchMetricRetrievalRecallAtFixedPrecision(TorchBaseConfig, TAct, TInd):
    adaptive_k: bool = False
    max_k: OptInt = None
    min_precision: float = 0.0


class TorchMetricRootMeanSquaredErrorUsingSlidingWindow(TorchBaseConfig):
    window_size: int = 8


class TorchMetricRunningMean(TorchBaseConfig, TNan):
    window: int = 5


class TorchMetricRunningSum(TorchBaseConfig, TNan):
    window: int = 5


class TorchMetricSacreBLEUScore(TorchBaseConfig):
    lowercase: bool = False
    n_gram: int = 4
    smooth: bool = False
    tokenize: Annotated[str, pyd.Field(pattern=_TOKENIZE_RE)] = '13a'
    weights: OptListFloat = None


class TorchMetricScaleInvariantSignalDistortionRatio(TorchBaseConfig):
    zero_mean: bool = False


class TorchMetricSignalDistortionRatio(TorchBaseConfig):
    filter_length: int = 512
    load_diag: OptFloat = None
    use_cg_iter: OptInt = None
    zero_mean: bool = False


class TorchMetricSignalNoiseRatio(TorchBaseConfig):
    zero_mean: bool = False


class TorchMetricSpearmanCorrCoef(TorchBaseConfig, TOut):
    pass


class TorchMetricSpectralAngleMapper(TorchBaseConfig, TMReduct):
    pass


class TorchMetricSpectralDistortionIndex(TorchBaseConfig, TMReduct):
    p: int = 1


class TorchMetricStructuralSimilarityIndexMeasure(TorchBaseConfig, TMReduct):
    data_range: OptPairFloat = None
    gaussian_kernel: bool = True
    k1: float = 0.01
    k2: float = 0.03
    kernel_size: Ints = 11
    return_contrast_sensitivity: bool = False
    return_full_image: bool = False
    sigma: Floats = 1.5


class TorchMetricSumMetric(TorchBaseConfig, TNan):
    pass


class TorchMetricTheilsU(TorchBaseConfig, TNanStrategy):
    nan_replace_value: OptFloat = 0.0
    num_classes: int


class TorchMetricTotalVariation(TorchBaseConfig):
    reduction: ReductionType = 'sum'


class TorchMetricTranslationEditRate(TorchBaseConfig):
    asian_support: bool = False
    lowercase: bool = True
    no_punctuation: bool = False
    normalize: bool = False
    return_sentence_level_score: bool = False


class TorchMetricTschuprowsT(TorchBaseConfig, TNanStrategy):
    bias_correction: bool = True
    nan_replace_value: OptFloat = 0.0
    num_classes: int


class TorchMetricTweedieDevianceScore(TorchBaseConfig):
    power: float = 0.0


class TorchMetricUniversalImageQualityIndex(TorchBaseConfig, TMReduct):
    kernel_size: tuple[int, ...] = (11, 11)
    sigma: tuple[float, ...] = (1.5, 1.5)
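

# Illustrative sketch (not part of the original module): pattern-constrained
# fields such as TorchMetricDice.average reject values outside their regex, so
# a typo in a config file fails fast with pydantic.ValidationError. The name
# value below is an assumption for demonstration only.
def _example_dice_config() -> TorchMetricDice:
    # average must match ^(micro|macro|weighted|samples|none)$
    return TorchMetricDice(name='Dice', average='macro')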