Coverage for /home/ubuntu/flatiron/python/flatiron/torch/config.py: 100%
340 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
1from typing import Optional, Union
2from typing_extensions import Annotated
3from flatiron.core.types import OptBool, Ints, OptInt, OptInts, OptStr
4from flatiron.core.types import Floats, OptFloat, OptListFloat, OptPairFloat
6import pydantic as pyd
7# ------------------------------------------------------------------------------
# Anchored alternation of the device names accepted by TorchFramework.device.
_DEVICE_RE = (
    '^(cpu|cuda|ipu|xpu|mkldnn|opengl|opencl|ideep|hip|ve|fpga|maia'
    '|xla|lazy|vulkan|mps|meta|hpu|mtia|privateuseone)$'
)
# Anchored alternation of the tokenizer names accepted by `tokenize` fields.
_TOKENIZE_RE = (
    '^(none|13a|zh|intl|char|ja-mecab|ko-mecab|flores101|flores200)$'
)
# String type restricted to the reduction modes 'mean', 'sum' or 'none'.
ReductionType = Annotated[
    str,
    pyd.Field(pattern='^(mean|sum|none)$'),
]
15# BASE--------------------------------------------------------------------------
class TorchBaseConfig(pyd.BaseModel):
    '''
    Common base for the torch config models in this module.

    Attributes:
        name (str): Required identifier of the configured object.
    '''
    name: str
20# FRAMEWORK---------------------------------------------------------------------
class TorchFramework(pyd.BaseModel):
    '''
    Configuration for calls to torch train function.

    Attributes:
        name (str): Framework name. Default: 'torch'.
        device (str): Hardware device; validated against _DEVICE_RE.
            Default: 'cuda'.
    '''
    name: str = 'torch'
    device: Annotated[
        str, pyd.Field(pattern=_DEVICE_RE)
    ] = 'cuda'
33# OPTIMIZER-HELPERS-------------------------------------------------------------
class TorchOptBaseConfig(TorchBaseConfig):
    '''
    Base class shared by all optimizer configs.

    Attributes:
        learning_rate (float): Learning rate. Default: 0.01.
    '''
    learning_rate: float = 0.01  # convert to lr
class TMax(pyd.BaseModel):
    '''
    Mixin contributing the `maximize` flag. Default: False.
    '''
    maximize: bool = False
class TFor(pyd.BaseModel):
    '''
    Mixin contributing the optional `foreach` flag. Default: None.
    '''
    foreach: OptBool = None
class TDiff(pyd.BaseModel):
    '''
    Mixin contributing the `differentiable` flag. Default: False.
    '''
    differentiable: bool = False
class TEps(pyd.BaseModel):
    '''
    Mixin contributing the `epsilon` field. Default: 1e-06.
    '''
    epsilon: float = 1e-06  # convert to eps
class TCap(pyd.BaseModel):
    '''
    Mixin contributing the `capturable` flag. Default: False.
    '''
    capturable: bool = False
class TDecay(pyd.BaseModel):
    '''
    Mixin contributing the `weight_decay` field. Default: 0.
    '''
    weight_decay: float = 0
class TBeta(pyd.BaseModel):
    '''
    Mixin contributing the two beta coefficients.

    Attributes:
        beta_1 (float): First coefficient. Default: 0.9.
        beta_2 (float): Second coefficient. Default: 0.999.
    '''
    beta_1: float = 0.9
    beta_2: float = 0.999  # convert to betas: tuple[float, float]
class TGroup1(TCap, TDecay, TDiff, TEps, TFor, TMax):
    '''
    Bundle of optimizer mixins: capturable, weight_decay, differentiable,
    epsilon, foreach and maximize. Declares no fields of its own.
    '''
71# LOSS-HELPERS------------------------------------------------------------------
class TReduct(pyd.BaseModel):
    '''
    Mixin contributing `reduction` ('mean', 'sum' or 'none'). Default: 'mean'.
    '''
    reduction: ReductionType = 'mean'
class TRed(pyd.BaseModel):
    '''
    Mixin contributing the optional `reduce` flag. Default: None.
    '''
    reduce: OptBool = None
class TSize(pyd.BaseModel):
    '''
    Mixin contributing the optional `size_average` flag. Default: None.
    '''
    size_average: OptBool = None
class TMarg(pyd.BaseModel):
    '''
    Mixin contributing the `margin` field. Default: 0.0.
    '''
    margin: float = 0.0
class TGroup2(TRed, TReduct, TSize):
    '''
    Bundle of loss mixins: reduce, reduction and size_average.
    Declares no fields of its own.
    '''
class TGroup3(TMarg, TRed, TReduct, TSize):
    '''
    Bundle of loss mixins: margin, reduce, reduction and size_average.
    Declares no fields of its own.
    '''
96# METRIC-HELPERS----------------------------------------------------------------
class TInd(pyd.BaseModel):
    '''
    Mixin contributing the optional `ignore_index` field. Default: None.
    '''
    ignore_index: OptInt = None
class TNan(pyd.BaseModel):
    '''
    Mixin contributing `nan_strategy` for aggregation-style metrics.

    Attributes:
        nan_strategy (float or str): Either a float or one of 'error',
            'warn', 'ignore', 'disable'. Default: 'warn'.
    '''
    nan_strategy: Union[
        float,
        Annotated[str, pyd.Field(pattern='^(error|warn|ignore|disable)$')],
    ] = 'warn'
class TAct(pyd.BaseModel):
    '''
    Mixin contributing `empty_target_action`.

    Attributes:
        empty_target_action (str): One of 'error', 'skip', 'neg', 'pos'.
            Default: 'neg'.
    '''
    empty_target_action: Annotated[
        str,
        pyd.Field(pattern='^(error|skip|neg|pos)$'),
    ] = 'neg'
class TOut(pyd.BaseModel):
    '''
    Mixin contributing the `num_outputs` field. Default: 1.
    '''
    num_outputs: int = 1
class TMReduct(pyd.BaseModel):
    '''
    Mixin contributing the metric-flavoured `reduction` field.

    Attributes:
        reduction (str): One of 'elementwise_mean', 'sum', 'none'.
            Default: 'elementwise_mean'.
    '''
    reduction: Annotated[
        str,
        pyd.Field(pattern='^(elementwise_mean|sum|none)$'),
    ] = 'elementwise_mean'
class TTopK(pyd.BaseModel):
    '''
    Mixin contributing the optional `top_k` field. Default: None.
    '''
    top_k: OptInt = None
class TCls(pyd.BaseModel):
    '''
    Mixin contributing the optional `num_classes` field. Default: None.
    '''
    num_classes: OptInt = None  # has multiple signatures
class TDate(pyd.BaseModel):
    '''
    Mixin contributing the optional `data_range` field. Default: None.
    '''
    # NOTE(review): the class name looks like a typo of "TData" given the
    # field it carries; left unchanged because other classes inherit from it.
    data_range: OptPairFloat = None
class TNanStrategy(pyd.BaseModel):
    '''
    Mixin contributing `nan_strategy` restricted to 'replace' or 'drop'.
    Default: 'replace'.
    '''
    nan_strategy: Annotated[
        str, pyd.Field(pattern='^(replace|drop)$')
    ] = 'replace'
139# OPTIMIZER---------------------------------------------------------------------
class TorchOptASGD(TorchOptBaseConfig, TCap, TDecay, TDiff, TFor, TMax):
    '''
    Settings for the ASGD optimizer.

    Attributes:
        alpha (float): Default: 0.75.
        lambd (float): Default: 0.0001.
        t0 (float): Default: 1000000.0.
    '''
    alpha: float = 0.75
    lambd: float = 0.0001
    t0: float = 1000000.0
class TorchOptAdadelta(TorchOptBaseConfig, TGroup1):
    '''
    Settings for the Adadelta optimizer.

    Attributes:
        rho (float): Default: 0.9.
    '''
    rho: float = 0.9
class TorchOptAdafactor(TorchOptBaseConfig, TDecay, TEps, TFor, TMax):
    '''
    Settings for the Adafactor optimizer.

    Attributes:
        beta2_decay (float): Default: -0.8.
        clipping_threshold (float): Default: 1.0.
    '''
    # NOTE(review): `epsilon` inherited from TEps is a single float, while
    # torch's Adafactor appears to take `eps` as a pair -- confirm the
    # downstream conversion handles this.
    beta2_decay: float = -0.8
    clipping_threshold: float = 1.0  # convert to d
class TorchOptAdagrad(TorchOptBaseConfig, TDecay, TDiff, TEps, TFor, TMax):
    '''
    Settings for the Adagrad optimizer.

    Attributes:
        fused (bool, optional): Default: None.
        initial_accumulator_value (float): Default: 0.
        lr_decay (float): Default: 0.
    '''
    fused: OptBool = None
    initial_accumulator_value: float = 0
    lr_decay: float = 0
class TorchOptAdam(TorchOptBaseConfig, TGroup1, TBeta):
    '''
    Settings for the Adam optimizer.

    Attributes:
        amsgrad (bool): Default: False.
        fused (bool, optional): Default: None.
    '''
    amsgrad: bool = False
    fused: OptBool = None
class TorchOptAdamW(TorchOptBaseConfig, TGroup1, TBeta):
    '''
    Settings for the AdamW optimizer.

    Attributes:
        amsgrad (bool): Default: False.
        fused (bool, optional): Default: None.
    '''
    amsgrad: bool = False
    fused: OptBool = None
class TorchOptAdamax(TorchOptBaseConfig, TGroup1, TBeta):
    '''
    Settings for the Adamax optimizer; all fields are inherited from
    TorchOptBaseConfig, TGroup1 and TBeta.
    '''
class TorchOptLBFGS(TorchOptBaseConfig):
    '''
    Settings for the LBFGS optimizer.

    Attributes:
        history_size (int): Default: 100.
        line_search_fn (str, optional): Default: None.
        max_eval (int, optional): Default: None.
        max_iter (int): Default: 20.
        tolerance_change (float): Default: 1e-09.
        tolerance_grad (float): Default: 1e-07.
    '''
    history_size: int = 100
    line_search_fn: OptStr = None
    max_eval: OptInt = None
    max_iter: int = 20
    tolerance_change: float = 1e-09
    tolerance_grad: float = 1e-07
class TorchOptNAdam(TorchOptBaseConfig, TGroup1, TBeta):
    '''
    Settings for the NAdam optimizer.

    Attributes:
        momentum_decay (float): Default: 0.004.
    '''
    momentum_decay: float = 0.004
class TorchOptRAdam(TorchOptBaseConfig, TGroup1, TBeta):
    '''
    Settings for the RAdam optimizer; all fields are inherited from
    TorchOptBaseConfig, TGroup1 and TBeta.
    '''
class TorchOptRMSprop(TorchOptBaseConfig, TGroup1):
    '''
    Settings for the RMSprop optimizer.

    Attributes:
        alpha (float): Default: 0.99.
        centered (bool): Default: False.
        momentum (float): Default: 0.
    '''
    alpha: float = 0.99
    centered: bool = False
    momentum: float = 0
class TorchOptRprop(TorchOptBaseConfig, TCap, TDiff, TFor, TMax):
    '''
    Settings for the Rprop optimizer.

    Attributes:
        etas (tuple[float, float]): Default: (0.5, 1.2).
        step_sizes (tuple[float, float]): Default: (1e-06, 50).
    '''
    etas: tuple[float, float] = (0.5, 1.2)
    step_sizes: tuple[float, float] = (1e-06, 50)
class TorchOptSGD(TorchOptBaseConfig, TDecay, TDiff, TFor, TMax):
    '''
    Settings for the SGD optimizer.

    Attributes:
        dampening (float): Default: 0.
        fused (bool, optional): Default: None.
        momentum (float): Default: 0.
        nesterov (bool): Default: False.
    '''
    dampening: float = 0
    fused: OptBool = None
    momentum: float = 0
    nesterov: bool = False
class TorchOptSparseAdam(TorchOptBaseConfig, TEps, TMax, TBeta):
    '''
    Settings for the SparseAdam optimizer; all fields are inherited from
    TorchOptBaseConfig, TEps, TMax and TBeta.
    '''
214# LOSS--------------------------------------------------------------------------
class TorchLossBCELoss(TorchBaseConfig, TGroup2):
    '''
    Settings for BCELoss; all fields are inherited from TorchBaseConfig
    and TGroup2.
    '''
class TorchLossBCEWithLogitsLoss(TorchBaseConfig, TGroup2):
    '''
    Settings for BCEWithLogitsLoss; all fields are inherited from
    TorchBaseConfig and TGroup2.
    '''
class TorchLossCTCLoss(TorchBaseConfig, TReduct):
    '''
    Settings for CTCLoss.

    Attributes:
        blank (int): Default: 0.
        zero_infinity (bool): Default: False.
    '''
    blank: int = 0
    zero_infinity: bool = False
class TorchLossCosineEmbeddingLoss(TorchBaseConfig, TGroup3):
    '''
    Settings for CosineEmbeddingLoss; all fields are inherited from
    TorchBaseConfig and TGroup3.
    '''
class TorchLossCrossEntropyLoss(TorchBaseConfig, TGroup2):
    '''
    Settings for CrossEntropyLoss.

    Attributes:
        ignore_index (int): Default: -100.
        label_smoothing (float): Default: 0.0.
    '''
    ignore_index: int = -100
    label_smoothing: float = 0.0
class TorchLossGaussianNLLLoss(TorchBaseConfig, TEps, TReduct):
    '''
    Settings for GaussianNLLLoss.

    Attributes:
        full (bool): Default: False.
    '''
    full: bool = False
class TorchLossHingeEmbeddingLoss(TorchBaseConfig, TGroup3):
    '''
    Settings for HingeEmbeddingLoss; all fields are inherited from
    TorchBaseConfig and TGroup3.
    '''
class TorchLossHuberLoss(TorchBaseConfig, TReduct):
    '''
    Settings for HuberLoss.

    Attributes:
        delta (float): Default: 1.0.
    '''
    delta: float = 1.0
class TorchLossKLDivLoss(TorchBaseConfig, TGroup2):
    '''
    Settings for KLDivLoss.

    Attributes:
        log_target (bool): Default: False.
    '''
    log_target: bool = False
class TorchLossL1Loss(TorchBaseConfig, TGroup2):
    '''
    Settings for L1Loss; all fields are inherited from TorchBaseConfig
    and TGroup2.
    '''
class TorchLossMSELoss(TorchBaseConfig, TGroup2):
    '''
    Settings for MSELoss; all fields are inherited from TorchBaseConfig
    and TGroup2.
    '''
class TorchLossMarginRankingLoss(TorchBaseConfig, TGroup3):
    '''
    Settings for MarginRankingLoss; all fields are inherited from
    TorchBaseConfig and TGroup3.
    '''
class TorchLossMultiLabelMarginLoss(TorchBaseConfig, TGroup2):
    '''
    Settings for MultiLabelMarginLoss; all fields are inherited from
    TorchBaseConfig and TGroup2.
    '''
class TorchLossMultiLabelSoftMarginLoss(TorchBaseConfig, TGroup2):
    '''
    Settings for MultiLabelSoftMarginLoss; all fields are inherited from
    TorchBaseConfig and TGroup2.
    '''
class TorchLossMultiMarginLoss(TorchBaseConfig, TGroup3):
    '''
    Settings for MultiMarginLoss.

    Attributes:
        exponent (int): Default: 1.
    '''
    exponent: int = 1  # convert to p
class TorchLossNLLLoss(TorchBaseConfig, TGroup2):
    '''
    Settings for NLLLoss.

    Attributes:
        ignore_index (int): Default: -100.
    '''
    ignore_index: int = -100
class TorchLossPairwiseDistance(TorchBaseConfig, TEps):
    '''
    Settings for PairwiseDistance.

    Attributes:
        keepdim (bool): Default: False.
        norm_degree (float): Default: 2.0.
    '''
    keepdim: bool = False
    norm_degree: float = 2.0  # convert to p
class TorchLossPoissonNLLLoss(TorchBaseConfig, TEps, TGroup2):
    '''
    Settings for PoissonNLLLoss.

    Attributes:
        full (bool): Default: False.
        log_input (bool): Default: True.
    '''
    full: bool = False
    log_input: bool = True
class TorchLossSmoothL1Loss(TorchBaseConfig, TGroup2):
    '''
    Settings for SmoothL1Loss.

    Attributes:
        beta (float): Default: 1.0.
    '''
    beta: float = 1.0
class TorchLossSoftMarginLoss(TorchBaseConfig, TGroup2):
    '''
    Settings for SoftMarginLoss; all fields are inherited from
    TorchBaseConfig and TGroup2.
    '''
class TorchLossTripletMarginLoss(TorchBaseConfig, TEps, TGroup3):
    '''
    Settings for TripletMarginLoss.

    Attributes:
        norm_degree (float): Default: 2.0.
        swap (bool): Default: False.
    '''
    norm_degree: float = 2.0  # convert to p
    swap: bool = False
class TorchLossTripletMarginWithDistanceLoss(TorchBaseConfig, TMarg, TReduct):
    '''
    Settings for TripletMarginWithDistanceLoss.

    Attributes:
        swap (bool): Default: False.
    '''
    swap: bool = False
308# METRICS-----------------------------------------------------------------------
class TorchMetricBLEUScore(TorchBaseConfig):
    '''
    Settings for the BLEUScore metric.

    Attributes:
        n_gram (int): Default: 4.
        smooth (bool): Default: False.
        weights (list[float], optional): Default: None.
    '''
    n_gram: int = 4
    smooth: bool = False
    weights: OptListFloat = None
class TorchMetricCHRFScore(TorchBaseConfig):
    '''
    Settings for the CHRFScore metric.

    Attributes:
        beta (float): Default: 2.0.
        lowercase (bool): Default: False.
        n_char_order (int): Default: 6.
        n_word_order (int): Default: 2.
        return_sentence_level_score (bool): Default: False.
        whitespace (bool): Default: False.
    '''
    beta: float = 2.0
    lowercase: bool = False
    n_char_order: int = 6
    n_word_order: int = 2
    return_sentence_level_score: bool = False
    whitespace: bool = False
class TorchMetricCatMetric(TorchBaseConfig, TNan):
    '''
    Settings for the CatMetric metric; all fields are inherited from
    TorchBaseConfig and TNan.
    '''
class TorchMetricConcordanceCorrCoef(TorchBaseConfig, TOut):
    '''
    Settings for the ConcordanceCorrCoef metric; all fields are inherited
    from TorchBaseConfig and TOut.
    '''
class TorchMetricCosineSimilarity(TorchBaseConfig):
    '''
    Settings for the CosineSimilarity metric.

    Attributes:
        reduction (str): 'mean', 'sum' or 'none'. Default: 'sum'.
    '''
    reduction: ReductionType = 'sum'
class TorchMetricCramersV(TorchBaseConfig, TCls, TNanStrategy):
    '''
    Settings for the CramersV metric.

    Attributes:
        bias_correction (bool): Default: True.
        nan_replace_value (float, optional): Default: 0.0.
    '''
    # NOTE(review): `num_classes` comes from TCls and is optional here, while
    # TorchMetricPearsonsContingencyCoefficient, TorchMetricTheilsU and
    # TorchMetricTschuprowsT declare it as a required int -- confirm intent.
    bias_correction: bool = True
    nan_replace_value: OptFloat = 0.0
class TorchMetricCriticalSuccessIndex(TorchBaseConfig):
    '''
    Settings for the CriticalSuccessIndex metric.

    Attributes:
        keep_sequence_dim (int, optional): Default: None.
        threshold (float): Required.
    '''
    keep_sequence_dim: OptInt = None
    threshold: float
class TorchMetricDice(TorchBaseConfig, TCls, TInd, TTopK):
    '''
    Settings for the Dice metric.

    Attributes:
        average (str, optional): One of 'micro', 'macro', 'weighted',
            'samples', 'none'. Default: 'micro'.
        mdmc_average (str, optional): 'samplewise' or 'global'.
            Default: 'global'.
        multiclass (bool, optional): Default: None.
        threshold (float): Default: 0.5.
        zero_division (int): Default: 0.
    '''
    average: Optional[
        Annotated[
            str,
            pyd.Field(pattern='^(micro|macro|weighted|samples|none)$'),
        ]
    ] = 'micro'
    mdmc_average: Optional[
        Annotated[str, pyd.Field(pattern='^(samplewise|global)$')]
    ] = 'global'
    multiclass: OptBool = None
    threshold: float = 0.5
    zero_division: int = 0
class TorchMetricErrorRelativeGlobalDimensionlessSynthesis(TorchBaseConfig, TMReduct):
    '''
    Settings for the ErrorRelativeGlobalDimensionlessSynthesis metric.

    Attributes:
        ratio (float): Default: 4.
    '''
    ratio: float = 4
class TorchMetricExplainedVariance(TorchBaseConfig):
    '''
    Settings for the ExplainedVariance metric.

    Attributes:
        multioutput (str): One of 'raw_values', 'uniform_average',
            'variance_weighted'. Default: 'uniform_average'.
    '''
    multioutput: Annotated[
        str,
        pyd.Field(pattern='^(raw_values|uniform_average|variance_weighted)$'),
    ] = 'uniform_average'
class TorchMetricExtendedEditDistance(TorchBaseConfig):
    '''
    Settings for the ExtendedEditDistance metric.

    Attributes:
        alpha (float): Default: 2.0.
        deletion (float): Default: 0.2.
        insertion (float): Default: 1.0.
        language (str): 'en' or 'ja'. Default: 'en'.
        return_sentence_level_score (bool): Default: False.
        rho (float): Default: 0.3.
    '''
    alpha: float = 2.0
    deletion: float = 0.2
    insertion: float = 1.0
    language: Annotated[
        str, pyd.Field(pattern='^(en|ja)$')
    ] = 'en'
    return_sentence_level_score: bool = False
    rho: float = 0.3
class TorchMetricFleissKappa(TorchBaseConfig):
    '''
    Settings for the FleissKappa metric.

    Attributes:
        mode (str): 'counts' or 'probs'. Default: 'counts'.
    '''
    mode: Annotated[
        str, pyd.Field(pattern='^(counts|probs)$')
    ] = 'counts'
class TorchMetricKLDivergence(TorchBaseConfig):
    '''
    Settings for the KLDivergence metric.

    Attributes:
        log_prob (bool): Default: False.
        reduction (str): 'mean', 'sum' or 'none'. Default: 'mean'.
    '''
    log_prob: bool = False
    reduction: ReductionType = 'mean'
class TorchMetricKendallRankCorrCoef(TorchBaseConfig, TOut):
    '''
    Settings for the KendallRankCorrCoef metric.

    Attributes:
        alternative (str, optional): 'two-sided', 'less' or 'greater'.
            Default: 'two-sided'.
        t_test (bool): Default: False.
        variant (str): 'a', 'b' or 'c'. Default: 'b'.
    '''
    alternative: Optional[
        Annotated[str, pyd.Field(pattern='^(two-sided|less|greater)$')]
    ] = 'two-sided'
    t_test: bool = False
    variant: Annotated[
        str, pyd.Field(pattern='^(a|b|c)$')
    ] = 'b'
class TorchMetricLogCoshError(TorchBaseConfig, TOut):
    '''
    Settings for the LogCoshError metric; all fields are inherited from
    TorchBaseConfig and TOut.
    '''
class TorchMetricMaxMetric(TorchBaseConfig, TNan):
    '''
    Settings for the MaxMetric metric; all fields are inherited from
    TorchBaseConfig and TNan.
    '''
class TorchMetricMeanAbsoluteError(TorchBaseConfig, TOut):
    '''
    Settings for the MeanAbsoluteError metric; all fields are inherited
    from TorchBaseConfig and TOut.
    '''
class TorchMetricMeanMetric(TorchBaseConfig, TNan):
    '''
    Settings for the MeanMetric metric; all fields are inherited from
    TorchBaseConfig and TNan.
    '''
class TorchMetricMeanSquaredError(TorchBaseConfig, TOut):
    '''
    Settings for the MeanSquaredError metric.

    Attributes:
        squared (bool): Default: True.
    '''
    squared: bool = True
class TorchMetricMinMetric(TorchBaseConfig, TNan):
    '''
    Settings for the MinMetric metric; all fields are inherited from
    TorchBaseConfig and TNan.
    '''
class TorchMetricMinkowskiDistance(TorchBaseConfig):
    '''
    Settings for the MinkowskiDistance metric.

    Attributes:
        p (float): Required.
    '''
    p: float
class TorchMetricModifiedPanopticQuality(TorchBaseConfig):
    '''
    Settings for the ModifiedPanopticQuality metric.

    Attributes:
        allow_unknown_preds_category (bool): Default: False.
        stuffs (list[int]): Required.
        things (list[int]): Required.
    '''
    allow_unknown_preds_category: bool = False
    stuffs: list[int]
    things: list[int]
class TorchMetricMultiScaleStructuralSimilarityIndexMeasure(TorchBaseConfig, TMReduct, TDate):
    '''
    Settings for the MultiScaleStructuralSimilarityIndexMeasure metric.

    Attributes:
        betas (tuple[float, ...]): Variable-length tuple of floats.
            Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
        gaussian_kernel (bool): Default: True.
        k1 (float): Default: 0.01.
        k2 (float): Default: 0.03.
        kernel_size (Ints): Default: 11.
        normalize (str, optional): 'relu' or 'simple'. Default: 'relu'.
        sigma (Floats): Default: 1.5.
    '''
    # Element type is spelled out so non-numeric entries are rejected at
    # validation time; a bare `tuple` would accept arbitrary contents.
    betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
    gaussian_kernel: bool = True
    k1: float = 0.01
    k2: float = 0.03
    kernel_size: Ints = 11
    normalize: Optional[
        Annotated[str, pyd.Field(pattern='^(relu|simple)$')]
    ] = 'relu'
    sigma: Floats = 1.5
class TorchMetricNormalizedRootMeanSquaredError(TorchBaseConfig, TOut):
    '''
    Settings for the NormalizedRootMeanSquaredError metric.

    Attributes:
        normalization (str): One of 'mean', 'range', 'std', 'l2'.
            Default: 'mean'.
    '''
    normalization: Annotated[
        str,
        pyd.Field(pattern='^(mean|range|std|l2)$'),
    ] = 'mean'
class TorchMetricPanopticQuality(TorchBaseConfig):
    '''
    Settings for the PanopticQuality metric.

    Attributes:
        allow_unknown_preds_category (bool): Default: False.
        stuffs (list[int]): Required.
        things (list[int]): Required.
    '''
    allow_unknown_preds_category: bool = False
    stuffs: list[int]
    things: list[int]
class TorchMetricPeakSignalNoiseRatio(TorchBaseConfig, TMReduct, TDate):
    '''
    Settings for the PeakSignalNoiseRatio metric.

    Attributes:
        base (float): Default: 10.0.
        dim (OptInts): Default: None.
    '''
    base: float = 10.0
    dim: OptInts = None
class TorchMetricPearsonCorrCoef(TorchBaseConfig, TOut):
    '''
    Settings for the PearsonCorrCoef metric; all fields are inherited from
    TorchBaseConfig and TOut.
    '''
class TorchMetricPearsonsContingencyCoefficient(TorchBaseConfig, TNanStrategy):
    '''
    Settings for the PearsonsContingencyCoefficient metric.

    Attributes:
        nan_replace_value (float, optional): Default: 0.0.
        num_classes (int): Required.
    '''
    nan_replace_value: OptFloat = 0.0
    num_classes: int
class TorchMetricPermutationInvariantTraining(TorchBaseConfig):
    '''
    Settings for the PermutationInvariantTraining metric.

    Attributes:
        eval_func (str): 'max' or 'min'. Default: 'max'.
        mode (str): 'speaker-wise' or 'permutation-wise'.
            Default: 'speaker-wise'.
    '''
    eval_func: Annotated[
        str, pyd.Field(pattern='^(max|min)$')
    ] = 'max'
    mode: Annotated[
        str,
        pyd.Field(pattern='^(speaker-wise|permutation-wise)$'),
    ] = 'speaker-wise'
class TorchMetricPerplexity(TorchBaseConfig, TInd):
    '''
    Settings for the Perplexity metric; all fields are inherited from
    TorchBaseConfig and TInd.
    '''
class TorchMetricR2Score(TorchBaseConfig):
    '''
    Settings for the R2Score metric.

    Attributes:
        adjusted (int): Default: 0.
        multioutput (str): One of 'raw_values', 'uniform_average',
            'variance_weighted'. Default: 'uniform_average'.
    '''
    adjusted: int = 0
    multioutput: Annotated[
        str,
        pyd.Field(pattern='^(raw_values|uniform_average|variance_weighted)$'),
    ] = 'uniform_average'
class TorchMetricRelativeAverageSpectralError(TorchBaseConfig):
    '''
    Settings for the RelativeAverageSpectralError metric.

    Attributes:
        window_size (int): Default: 8.
    '''
    window_size: int = 8
class TorchMetricRelativeSquaredError(TorchBaseConfig, TOut):
    '''
    Settings for the RelativeSquaredError metric.

    Attributes:
        squared (bool): Default: True.
    '''
    squared: bool = True
class TorchMetricRetrievalFallOut(TorchBaseConfig, TInd, TTopK):
    '''
    Settings for the RetrievalFallOut metric.

    Attributes:
        empty_target_action (str): One of 'error', 'skip', 'neg', 'pos'.
            Default: 'pos'.
    '''
    # Declared here rather than via TAct because the default is 'pos',
    # whereas TAct defaults to 'neg'.
    empty_target_action: Annotated[
        str,
        pyd.Field(pattern='^(error|skip|neg|pos)$'),
    ] = 'pos'
class TorchMetricRetrievalHitRate(TorchBaseConfig, TAct, TInd, TTopK):
    '''
    Settings for the RetrievalHitRate metric; all fields are inherited from
    TorchBaseConfig, TAct, TInd and TTopK.
    '''
class TorchMetricRetrievalMAP(TorchBaseConfig, TAct, TInd, TTopK):
    '''
    Settings for the RetrievalMAP metric; all fields are inherited from
    TorchBaseConfig, TAct, TInd and TTopK.
    '''
class TorchMetricRetrievalMRR(TorchBaseConfig, TAct, TInd):
    '''
    Settings for the RetrievalMRR metric; all fields are inherited from
    TorchBaseConfig, TAct and TInd.
    '''
class TorchMetricRetrievalNormalizedDCG(TorchBaseConfig, TAct, TInd, TTopK):
    '''
    Settings for the RetrievalNormalizedDCG metric; all fields are
    inherited from TorchBaseConfig, TAct, TInd and TTopK.
    '''
class TorchMetricRetrievalPrecision(TorchBaseConfig, TAct, TInd, TTopK):
    '''
    Settings for the RetrievalPrecision metric.

    Attributes:
        adaptive_k (bool): Default: False.
    '''
    adaptive_k: bool = False
class TorchMetricRetrievalPrecisionRecallCurve(TorchBaseConfig, TAct, TInd):
    '''
    Settings for the RetrievalPrecisionRecallCurve metric.

    TAct is mixed in so `empty_target_action` (default 'neg') can be
    configured, matching TorchMetricRetrievalRecallAtFixedPrecision and the
    other retrieval metric configs in this module.

    Attributes:
        adaptive_k (bool): Default: False.
        max_k (int, optional): Default: None.
    '''
    adaptive_k: bool = False
    max_k: OptInt = None
class TorchMetricRetrievalRPrecision(TorchBaseConfig, TAct, TInd):
    '''
    Settings for the RetrievalRPrecision metric; all fields are inherited
    from TorchBaseConfig, TAct and TInd.
    '''
class TorchMetricRetrievalRecall(TorchBaseConfig, TAct, TInd, TTopK):
    '''
    Settings for the RetrievalRecall metric; all fields are inherited from
    TorchBaseConfig, TAct, TInd and TTopK.
    '''
class TorchMetricRetrievalRecallAtFixedPrecision(TorchBaseConfig, TAct, TInd):
    '''
    Settings for the RetrievalRecallAtFixedPrecision metric.

    Attributes:
        adaptive_k (bool): Default: False.
        max_k (int, optional): Default: None.
        min_precision (float): Default: 0.0.
    '''
    adaptive_k: bool = False
    max_k: OptInt = None
    min_precision: float = 0.0
class TorchMetricRootMeanSquaredErrorUsingSlidingWindow(TorchBaseConfig):
    '''
    Settings for the RootMeanSquaredErrorUsingSlidingWindow metric.

    Attributes:
        window_size (int): Default: 8.
    '''
    window_size: int = 8
class TorchMetricRunningMean(TorchBaseConfig, TNan):
    '''
    Settings for the RunningMean metric.

    Attributes:
        window (int): Default: 5.
    '''
    window: int = 5
class TorchMetricRunningSum(TorchBaseConfig, TNan):
    '''
    Settings for the RunningSum metric.

    Attributes:
        window (int): Default: 5.
    '''
    window: int = 5
class TorchMetricSacreBLEUScore(TorchBaseConfig):
    '''
    Settings for the SacreBLEUScore metric.

    Attributes:
        lowercase (bool): Default: False.
        n_gram (int): Default: 4.
        smooth (bool): Default: False.
        tokenize (str): Validated against _TOKENIZE_RE. Default: '13a'.
        weights (list[float], optional): Default: None.
    '''
    lowercase: bool = False
    n_gram: int = 4
    smooth: bool = False
    tokenize: Annotated[
        str, pyd.Field(pattern=_TOKENIZE_RE)
    ] = '13a'
    weights: OptListFloat = None
class TorchMetricScaleInvariantSignalDistortionRatio(TorchBaseConfig):
    '''
    Settings for the ScaleInvariantSignalDistortionRatio metric.

    Attributes:
        zero_mean (bool): Default: False.
    '''
    zero_mean: bool = False
class TorchMetricSignalDistortionRatio(TorchBaseConfig):
    '''
    Settings for the SignalDistortionRatio metric.

    Attributes:
        filter_length (int): Default: 512.
        load_diag (float, optional): Default: None.
        use_cg_iter (int, optional): Default: None.
        zero_mean (bool): Default: False.
    '''
    filter_length: int = 512
    load_diag: OptFloat = None
    use_cg_iter: OptInt = None
    zero_mean: bool = False
class TorchMetricSignalNoiseRatio(TorchBaseConfig):
    '''
    Settings for the SignalNoiseRatio metric.

    Attributes:
        zero_mean (bool): Default: False.
    '''
    zero_mean: bool = False
class TorchMetricSpearmanCorrCoef(TorchBaseConfig, TOut):
    '''
    Settings for the SpearmanCorrCoef metric; all fields are inherited
    from TorchBaseConfig and TOut.
    '''
class TorchMetricSpectralAngleMapper(TorchBaseConfig, TMReduct):
    '''
    Settings for the SpectralAngleMapper metric; all fields are inherited
    from TorchBaseConfig and TMReduct.
    '''
class TorchMetricSpectralDistortionIndex(TorchBaseConfig, TMReduct):
    '''
    Settings for the SpectralDistortionIndex metric.

    Attributes:
        p (int): Default: 1.
    '''
    p: int = 1
class TorchMetricStructuralSimilarityIndexMeasure(TorchBaseConfig, TMReduct):
    '''
    Settings for the StructuralSimilarityIndexMeasure metric.

    Attributes:
        data_range (OptPairFloat): Default: None.
        gaussian_kernel (bool): Default: True.
        k1 (float): Default: 0.01.
        k2 (float): Default: 0.03.
        kernel_size (Ints): Default: 11.
        return_contrast_sensitivity (bool): Default: False.
        return_full_image (bool): Default: False.
        sigma (Floats): Default: 1.5.
    '''
    data_range: OptPairFloat = None
    gaussian_kernel: bool = True
    k1: float = 0.01
    k2: float = 0.03
    kernel_size: Ints = 11
    return_contrast_sensitivity: bool = False
    return_full_image: bool = False
    sigma: Floats = 1.5
class TorchMetricSumMetric(TorchBaseConfig, TNan):
    '''
    Settings for the SumMetric metric; all fields are inherited from
    TorchBaseConfig and TNan.
    '''
class TorchMetricTheilsU(TorchBaseConfig, TNanStrategy):
    '''
    Settings for the TheilsU metric.

    Attributes:
        nan_replace_value (float, optional): Default: 0.0.
        num_classes (int): Required.
    '''
    nan_replace_value: OptFloat = 0.0
    num_classes: int
class TorchMetricTotalVariation(TorchBaseConfig):
    '''
    Settings for the TotalVariation metric.

    Attributes:
        reduction (str): 'mean', 'sum' or 'none'. Default: 'sum'.
    '''
    reduction: ReductionType = 'sum'
class TorchMetricTranslationEditRate(TorchBaseConfig):
    '''
    Settings for the TranslationEditRate metric.

    Attributes:
        asian_support (bool): Default: False.
        lowercase (bool): Default: True.
        no_punctuation (bool): Default: False.
        normalize (bool): Default: False.
        return_sentence_level_score (bool): Default: False.
    '''
    asian_support: bool = False
    lowercase: bool = True
    no_punctuation: bool = False
    normalize: bool = False
    return_sentence_level_score: bool = False
class TorchMetricTschuprowsT(TorchBaseConfig, TNanStrategy):
    '''
    Settings for the TschuprowsT metric.

    Attributes:
        bias_correction (bool): Default: True.
        nan_replace_value (float, optional): Default: 0.0.
        num_classes (int): Required.
    '''
    bias_correction: bool = True
    nan_replace_value: OptFloat = 0.0
    num_classes: int
class TorchMetricTweedieDevianceScore(TorchBaseConfig):
    '''
    Settings for the TweedieDevianceScore metric.

    Attributes:
        power (float): Default: 0.0.
    '''
    power: float = 0.0
class TorchMetricUniversalImageQualityIndex(TorchBaseConfig, TMReduct):
    '''
    Settings for the UniversalImageQualityIndex metric.

    Attributes:
        kernel_size (tuple[int, ...]): Default: (11, 11).
        sigma (tuple[float, ...]): Default: (1.5, 1.5).
    '''
    kernel_size: tuple[int, ...] = (11, 11)
    sigma: tuple[float, ...] = (1.5, 1.5)