Coverage for /home/ubuntu/flatiron/python/flatiron/tf/config.py: 100%
222 statements
from typing import Optional
from typing_extensions import Annotated
from flatiron.core.types import OptInt, OptFloat, OptStr, OptFloats, OptListFloat

import pydantic as pyd
# ------------------------------------------------------------------------------


_DTYPE_RE = '^(bfloat16|bool|complex128|complex64|double'
_DTYPE_RE += '|float(16|32|64)|half|int(8|16|32|64)|qint(8|16|32)|quint(8|16)'
_DTYPE_RE += '|resource|string|uint(8|16|32|64)|variant)$'
DType = Optional[Annotated[str, pyd.Field(pattern=_DTYPE_RE)]]
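

# Illustrative sketch (hypothetical model, not part of the original module):
# DType constrains an optional string field to TensorFlow dtype names via
# _DTYPE_RE. The _DTypeExample model below exists only to demonstrate the
# validation behavior.
class _DTypeExample(pyd.BaseModel):
    dtype: DType = None

# _DTypeExample(dtype='float32') validates, while _DTypeExample(dtype='float128')
# raises pydantic.ValidationError because it does not match _DTYPE_RE.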


class TFBaseConfig(pyd.BaseModel):
    name: str


# FRAMEWORK---------------------------------------------------------------------
class TFFramework(pyd.BaseModel):
    '''
    Configuration for calls to model.compile.

    See: https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile

    Attributes:
        name (str): Framework name. Default: 'tensorflow'.
        device (str, optional): Hardware device. Default: 'gpu'.
        loss_weights (list[float], optional): List of loss weights.
            Default: None.
        weighted_metrics (list[float], optional): List of metric weights.
            Default: None.
        run_eagerly (bool, optional): Run model eagerly (leave as False).
            Default: False.
        steps_per_execution (int, optional): Number of batches per function
            call. Default: 1.
        jit_compile (bool, optional): Use XLA. Default: False.
        auto_scale_loss (bool, optional): Automatically scale the loss when the
            model dtype is mixed_float16. Default: True.
    '''
    name: str = 'tensorflow'
    device: Annotated[str, pyd.Field(pattern='^(cpu|gpu)$')] = 'gpu'
    loss_weights: OptListFloat = None
    weighted_metrics: OptListFloat = None
    run_eagerly: bool = False
    steps_per_execution: Annotated[int, pyd.Field(gt=0)] = 1
    jit_compile: bool = False
    auto_scale_loss: bool = True
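

# Illustrative sketch (hypothetical helper, not part of the original module):
# most TFFramework fields correspond to keyword arguments of
# tf.keras.Model.compile, while name and device are consumed by the framework
# wrapper itself. One way a caller might split them:
def _compile_kwargs(config: TFFramework) -> dict:
    kwargs = config.model_dump()
    kwargs.pop('name')    # framework selector, not a compile argument
    kwargs.pop('device')  # hardware placement, handled outside compile
    return kwargs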


# OPTIMIZER-HELPERS-------------------------------------------------------------
class TFOptBaseConfig(TFBaseConfig):
    clipnorm: OptFloat = None
    clipvalue: OptFloat = None
    ema_momentum: Annotated[float, pyd.Field(ge=0)] = 0.99
    ema_overwrite_frequency: Optional[Annotated[int, pyd.Field(ge=0)]] = None
    global_clipnorm: OptFloat = None
    gradient_accumulation_steps: Optional[Annotated[int, pyd.Field(gt=0)]] = None
    learning_rate: Optional[Annotated[float, pyd.Field(gt=0)]] = 0.001
    loss_scale_factor: OptFloat = None
    use_ema: bool = False
    # weight_decay: OptFloat = None  # deprecated


class TFEpsilon(pyd.BaseModel):
    epsilon: Annotated[float, pyd.Field(gt=0)] = 1e-07


class TFBeta(pyd.BaseModel):
    beta_1: float = 0.9
    beta_2: float = 0.99


# LOSS-HELPERS------------------------------------------------------------------
class TFLossBaseConfig(TFBaseConfig):
    dtype: DType = None
    reduction: Annotated[
        str, pyd.Field(pattern='^(auto|none|sum|sum_over_batch_size)$')
    ] = 'sum_over_batch_size'


class TFAxis(pyd.BaseModel):
    axis: int = -1


class TFLogits(pyd.BaseModel):
    from_logits: bool = False


# METRIC-HELPERS----------------------------------------------------------------
class TFMetricBaseConfig(TFBaseConfig):
    dtype: DType = None


class TFThresh(pyd.BaseModel):
    thresholds: OptFloats = None


class TFClsId(pyd.BaseModel):
    class_id: OptInt = None


class TFNumThresh(pyd.BaseModel):
    num_thresholds: Annotated[int, pyd.Field(gt=1)] = 200


class TFNumClasses(pyd.BaseModel):
    num_classes: Annotated[int, pyd.Field(ge=0)]


class TFIgnoreClass(pyd.BaseModel):
    ignore_class: OptInt = None


# OPTIMIZERS--------------------------------------------------------------------
class TFOptAdafactor(TFOptBaseConfig):
    beta_2_decay: float = -0.8
    clip_threshold: float = 1.0
    epsilon_1: Annotated[float, pyd.Field(gt=0)] = 1e-30
    epsilon_2: Annotated[float, pyd.Field(gt=0)] = 0.001
    relative_step: bool = True


class TFOptFtrl(TFOptBaseConfig):
    beta: float = 0.0
    initial_accumulator_value: float = 0.1
    l1_regularization_strength: float = 0.0
    l2_regularization_strength: float = 0.0
    l2_shrinkage_regularization_strength: float = 0.0
    learning_rate_power: Annotated[float, pyd.Field(le=0)] = -0.5


class TFOptLion(TFOptBaseConfig, TFBeta):
    pass


class TFOptSGD(TFOptBaseConfig):
    momentum: Annotated[float, pyd.Field(ge=0)] = 0.0
    nesterov: bool = False


class TFOptAdadelta(TFOptBaseConfig, TFEpsilon):
    rho: float = 0.95


class TFOptAdagrad(TFOptBaseConfig, TFEpsilon):
    initial_accumulator_value: float = 0.1


class TFOptAdam(TFOptBaseConfig, TFBeta, TFEpsilon):
    amsgrad: bool = False


class TFOptAdamW(TFOptBaseConfig, TFBeta, TFEpsilon):
    amsgrad: bool = False
    weight_decay: float = 0.004


class TFOptAdamax(TFOptBaseConfig, TFBeta, TFEpsilon):
    pass


class TFOptLamb(TFOptBaseConfig, TFBeta, TFEpsilon):
    pass


class TFOptNadam(TFOptBaseConfig, TFBeta, TFEpsilon):
    pass


class TFOptRMSprop(TFOptBaseConfig, TFEpsilon):
    centered: bool = False
    momentum: float = 0.0
    rho: float = 0.9
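

# Illustrative sketch (hypothetical usage, not part of the original module):
# TFOptAdam mixes in TFOptBaseConfig, TFBeta, and TFEpsilon, so one instance
# validates the shared optimizer fields alongside beta_1, beta_2, and epsilon.
# The _adam_example function and its values are illustrative assumptions.
def _adam_example() -> dict:
    # Arbitrary values; unset fields fall back to the defaults declared above.
    config = TFOptAdam(name='Adam', learning_rate=0.001, beta_1=0.9)
    # model_dump returns a plain dict of keyword arguments that a caller could
    # forward to a Keras Adam constructor.
    return config.model_dump()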


# LOSSES------------------------------------------------------------------------
class TFLossBinaryCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    label_smoothing: float = 0.0


class TFLossBinaryFocalCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    alpha: float = 0.25
    apply_class_balancing: bool = False
    gamma: float = 2.0
    label_smoothing: float = 0.0


class TFLossCategoricalCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    label_smoothing: float = 0.0


class TFLossCategoricalFocalCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    alpha: float = 0.25
    gamma: float = 2.0
    label_smoothing: float = 0.0


class TFLossCircle(TFLossBaseConfig):
    gamma: float = 80.0
    margin: float = 0.4
    remove_diagonal: bool = True


class TFLossCosineSimilarity(TFLossBaseConfig, TFAxis):
    pass


class TFLossDice(TFLossBaseConfig, TFAxis):
    pass


class TFLossHuber(TFLossBaseConfig):
    delta: float = 1.0


class TFLossSparseCategoricalCrossentropy(TFLossBaseConfig, TFLogits, TFIgnoreClass):
    pass


class TFLossTversky(TFLossBaseConfig, TFAxis):
    alpha: float = 0.5
    beta: float = 0.5
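

# Illustrative sketch (hypothetical usage, not part of the original module):
# loss configs inherit dtype and reduction from TFLossBaseConfig, and the
# reduction pattern rejects any value outside auto, none, sum, and
# sum_over_batch_size. The _dice_loss_example function and its values are
# illustrative assumptions.
def _dice_loss_example() -> dict:
    config = TFLossDice(name='Dice', axis=-1, reduction='sum_over_batch_size')
    # TFLossDice(name='Dice', reduction='mean') would raise
    # pydantic.ValidationError, because 'mean' does not match the pattern.
    return config.model_dump()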


# METRICS-----------------------------------------------------------------------
class TFMetricAUC(TFMetricBaseConfig, TFLogits, TFNumThresh, TFThresh):
    curve: Annotated[str, pyd.Field(pattern='^(ROC|PR)$')] = 'ROC'
    label_weights: OptListFloat = None
    multi_label: bool = False
    num_labels: OptInt = None
    summation_method: Annotated[
        str, pyd.Field(pattern='^(interpolation|minoring|majoring)$')
    ] = 'interpolation'


class TFMetricAccuracy(TFMetricBaseConfig):
    pass


class TFMetricBinaryAccuracy(TFMetricBaseConfig):
    threshold: float = 0.5


class TFMetricBinaryCrossentropy(TFMetricBaseConfig, TFLogits):
    label_smoothing: int = 0


class TFMetricBinaryIoU(TFMetricBaseConfig):
    target_class_ids: list[int] = [0, 1]
    threshold: float = 0.5


class TFMetricCategoricalAccuracy(TFMetricBaseConfig):
    pass


class TFMetricCategoricalCrossentropy(TFMetricBaseConfig, TFAxis, TFLogits):
    label_smoothing: int = 0


class TFMetricCategoricalHinge(TFMetricBaseConfig):
    pass


class TFMetricConcordanceCorrelation(TFMetricBaseConfig, TFAxis):
    pass


class TFMetricCosineSimilarity(TFMetricBaseConfig, TFAxis):
    pass


class TFMetricF1Score(TFMetricBaseConfig):
    average: OptStr = None
    threshold: OptFloat = None


class TFMetricFBetaScore(TFMetricBaseConfig):
    average: OptStr = None
    beta: float = 1.0
    threshold: OptFloat = None


class TFMetricFalseNegatives(TFMetricBaseConfig, TFThresh):
    pass


class TFMetricFalsePositives(TFMetricBaseConfig, TFThresh):
    pass


class TFMetricHinge(TFMetricBaseConfig):
    pass


class TFMetricIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = True
    sparse_y_true: bool = True
    target_class_ids: list[int]


class TFMetricKLDivergence(TFMetricBaseConfig):
    pass


class TFMetricLogCoshError(TFMetricBaseConfig):
    pass


class TFMetricMean(TFMetricBaseConfig):
    pass


class TFMetricMeanAbsoluteError(TFMetricBaseConfig):
    pass


class TFMetricMeanAbsolutePercentageError(TFMetricBaseConfig):
    pass


class TFMetricMeanIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = True
    sparse_y_true: bool = True


class TFMetricMeanSquaredError(TFMetricBaseConfig):
    pass


class TFMetricMeanSquaredLogarithmicError(TFMetricBaseConfig):
    pass


class TFMetricMetric(TFMetricBaseConfig):
    pass


class TFMetricOneHotIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = False
    target_class_ids: list[int]


class TFMetricOneHotMeanIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = False


class TFMetricPearsonCorrelation(TFMetricBaseConfig, TFAxis):
    pass


class TFMetricPoisson(TFMetricBaseConfig):
    pass


class TFMetricPrecision(TFMetricBaseConfig, TFClsId, TFThresh):
    top_k: OptInt = None


class TFMetricPrecisionAtRecall(TFMetricBaseConfig, TFClsId, TFNumThresh):
    recall: float


class TFMetricR2Score(TFMetricBaseConfig):
    class_aggregation: Optional[Annotated[
        str, pyd.Field(pattern='^(uniform_average|variance_weighted_average)$')
    ]] = 'uniform_average'
    num_regressors: Annotated[int, pyd.Field(ge=0)] = 0


class TFMetricRecall(TFMetricBaseConfig, TFClsId, TFThresh):
    top_k: OptInt = None


class TFMetricRecallAtPrecision(TFMetricBaseConfig, TFClsId, TFNumThresh):
    precision: float


class TFMetricRootMeanSquaredError(TFMetricBaseConfig):
    pass


class TFMetricSensitivityAtSpecificity(TFMetricBaseConfig, TFClsId, TFNumThresh):
    specificity: float


class TFMetricSparseCategoricalAccuracy(TFMetricBaseConfig):
    pass


class TFMetricSparseCategoricalCrossentropy(TFMetricBaseConfig, TFAxis, TFLogits):
    pass


class TFMetricSparseTopKCategoricalAccuracy(TFMetricBaseConfig):
    from_sorted_ids: bool = False
    k: int = 5


class TFMetricSpecificityAtSensitivity(TFMetricBaseConfig, TFClsId, TFNumThresh):
    sensitivity: float


class TFMetricSquaredHinge(TFMetricBaseConfig):
    pass


class TFMetricSum(TFMetricBaseConfig):
    pass


class TFMetricTopKCategoricalAccuracy(TFMetricBaseConfig):
    k: int = 5


class TFMetricTrueNegatives(TFMetricBaseConfig, TFThresh):
    pass


class TFMetricTruePositives(TFMetricBaseConfig, TFThresh):
    pass
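

# Illustrative sketch (hypothetical usage, not part of the original module):
# metric configs with required fields, such as TFMetricIoU, must be given
# num_classes and target_class_ids explicitly; the mixin fields fall back to
# their defaults. The _iou_metric_example function and its values are
# illustrative assumptions.
def _iou_metric_example() -> dict:
    config = TFMetricIoU(name='IoU', num_classes=3, target_class_ids=[0, 1, 2])
    # Omitting num_classes or target_class_ids raises pydantic.ValidationError,
    # since neither field declares a default.
    return config.model_dump()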