Coverage for /home/ubuntu/flatiron/python/flatiron/tf/config.py: 100%


from typing import Optional
from typing_extensions import Annotated
from flatiron.core.types import OptInt, OptFloat, OptStr, OptFloats, OptListFloat

import pydantic as pyd
# ------------------------------------------------------------------------------


_DTYPE_RE = '^(bfloat16|bool|complex128|complex64|double'
_DTYPE_RE += '|float(16|32|64)|half|int(8|16|32|64)|qint(8|16|32)|quint(8|16)'
_DTYPE_RE += '|resource|string|uint(8|16|32|64)|variant)$'
DType = Optional[Annotated[str, pyd.Field(pattern=_DTYPE_RE)]]
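
# Illustrative sketch: DType is an optional, regex-constrained string, so dtype
# names can be checked with pydantic's TypeAdapter before they reach TensorFlow.
# _dtype_example is a hypothetical helper used only for illustration.
_dtype_example = pyd.TypeAdapter(DType)
assert _dtype_example.validate_python('float32') == 'float32'
assert _dtype_example.validate_python(None) is None  # DType is Optional
# _dtype_example.validate_python('float128') would raise pydantic.ValidationError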


class TFBaseConfig(pyd.BaseModel):
    name: str


# FRAMEWORK---------------------------------------------------------------------
class TFFramework(pyd.BaseModel):
    '''
    Configuration for calls to model.compile.

    See: https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile

    Attributes:
        name (str): Framework name. Default: 'tensorflow'.
        device (str, optional): Hardware device. Default: 'gpu'.
        loss_weights (list[float], optional): List of loss weights.
            Default: None.
        weighted_metrics (list[float], optional): List of metric weights.
            Default: None.
        run_eagerly (bool, optional): Leave as False. Default: False.
        steps_per_execution (int, optional): Number of batches per function
            call. Default: 1.
        jit_compile (bool, optional): Use XLA. Default: False.
        auto_scale_loss (bool, optional): Dynamically scale the loss to prevent
            underflow when the model dtype policy is mixed_float16.
            Default: True.
    '''
    name: str = 'tensorflow'
    device: Annotated[str, pyd.Field(pattern='^(cpu|gpu)$')] = 'gpu'
    loss_weights: OptListFloat = None
    weighted_metrics: OptListFloat = None
    run_eagerly: bool = False
    steps_per_execution: Annotated[int, pyd.Field(gt=0)] = 1
    jit_compile: bool = False
    auto_scale_loss: bool = True
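
# Illustrative sketch: TFFramework validates a plain dict (e.g. one section of a
# YAML/JSON config) before any TensorFlow object is built. The dict below is
# hypothetical; the field names come from the class above.
_framework_example = TFFramework.model_validate(
    {'device': 'gpu', 'jit_compile': True}
)
assert _framework_example.name == 'tensorflow'
assert _framework_example.steps_per_execution == 1
# TFFramework(device='tpu') would raise pydantic.ValidationError, since device
# is constrained to the pattern '^(cpu|gpu)$'.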


# OPTIMIZER-HELPERS-------------------------------------------------------------
class TFOptBaseConfig(TFBaseConfig):
    clipnorm: OptFloat = None
    clipvalue: OptFloat = None
    ema_momentum: Annotated[float, pyd.Field(ge=0)] = 0.99
    ema_overwrite_frequency: Optional[Annotated[int, pyd.Field(ge=0)]] = None
    global_clipnorm: OptFloat = None
    gradient_accumulation_steps: Optional[Annotated[int, pyd.Field(gt=0)]] = None
    learning_rate: Optional[Annotated[float, pyd.Field(gt=0)]] = 0.001
    loss_scale_factor: OptFloat = None
    use_ema: bool = False
    # weight_decay: OptFloat = None  # deprecated


class TFEpsilon(pyd.BaseModel):
    epsilon: Annotated[float, pyd.Field(gt=0)] = 1e-07


class TFBeta(pyd.BaseModel):
    beta_1: float = 0.9
    beta_2: float = 0.99


# LOSS-HELPERS------------------------------------------------------------------
class TFLossBaseConfig(TFBaseConfig):
    dtype: DType = None
    reduction: Annotated[
        str, pyd.Field(pattern='^(auto|none|sum|sum_over_batch_size)$')
    ] = 'sum_over_batch_size'


class TFAxis(pyd.BaseModel):
    axis: int = -1


class TFLogits(pyd.BaseModel):
    from_logits: bool = False


# METRIC-HELPERS----------------------------------------------------------------
class TFMetricBaseConfig(TFBaseConfig):
    dtype: DType = None


class TFThresh(pyd.BaseModel):
    thresholds: OptFloats = None


class TFClsId(pyd.BaseModel):
    class_id: OptInt = None


class TFNumThresh(pyd.BaseModel):
    num_thresholds: Annotated[int, pyd.Field(gt=1)] = 200


class TFNumClasses(pyd.BaseModel):
    num_classes: Annotated[int, pyd.Field(ge=0)]


class TFIgnoreClass(pyd.BaseModel):
    ignore_class: OptInt = None


# OPTIMIZERS--------------------------------------------------------------------
class TFOptAdafactor(TFOptBaseConfig):
    beta_2_decay: float = -0.8
    clip_threshold: float = 1.0
    epsilon_1: Annotated[float, pyd.Field(gt=0)] = 1e-30
    epsilon_2: Annotated[float, pyd.Field(gt=0)] = 0.001
    relative_step: bool = True


class TFOptFtrl(TFOptBaseConfig):
    beta: float = 0.0
    initial_accumulator_value: float = 0.1
    l1_regularization_strength: float = 0.0
    l2_regularization_strength: float = 0.0
    l2_shrinkage_regularization_strength: float = 0.0
    learning_rate_power: Annotated[float, pyd.Field(le=0)] = -0.5


class TFOptLion(TFOptBaseConfig, TFBeta):
    pass


class TFOptSGD(TFOptBaseConfig):
    momentum: Annotated[float, pyd.Field(ge=0)] = 0.0
    nesterov: bool = False


class TFOptAdadelta(TFOptBaseConfig, TFEpsilon):
    rho: float = 0.95


class TFOptAdagrad(TFOptBaseConfig, TFEpsilon):
    initial_accumulator_value: float = 0.1


class TFOptAdam(TFOptBaseConfig, TFBeta, TFEpsilon):
    amsgrad: bool = False
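
# Illustrative sketch: optimizer configs are composed from small mixins via
# multiple inheritance, so TFOptAdam picks up beta_1/beta_2 from TFBeta and
# epsilon from TFEpsilon on top of the shared TFOptBaseConfig fields.
# _adam_example is a hypothetical instance used only for illustration.
_adam_example = TFOptAdam(name='Adam')
assert {'beta_1', 'beta_2', 'epsilon', 'amsgrad'} <= set(TFOptAdam.model_fields)
assert (_adam_example.beta_1, _adam_example.epsilon) == (0.9, 1e-07)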


class TFOptAdamW(TFOptBaseConfig, TFBeta, TFEpsilon):
    amsgrad: bool = False
    weight_decay: float = 0.004


class TFOptAdamax(TFOptBaseConfig, TFBeta, TFEpsilon):
    pass


class TFOptLamb(TFOptBaseConfig, TFBeta, TFEpsilon):
    pass


class TFOptNadam(TFOptBaseConfig, TFBeta, TFEpsilon):
    pass


class TFOptRMSprop(TFOptBaseConfig, TFEpsilon):
    centered: bool = False
    momentum: float = 0.0
    rho: float = 0.9


# LOSSES------------------------------------------------------------------------
class TFLossBinaryCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    label_smoothing: float = 0.0


class TFLossBinaryFocalCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    alpha: float = 0.25
    apply_class_balancing: bool = False
    gamma: float = 2.0
    label_smoothing: float = 0.0


class TFLossCategoricalCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    label_smoothing: float = 0.0


class TFLossCategoricalFocalCrossentropy(TFLossBaseConfig, TFAxis, TFLogits):
    alpha: float = 0.25
    gamma: float = 2.0
    label_smoothing: float = 0.0


class TFLossCircle(TFLossBaseConfig):
    gamma: float = 80.0
    margin: float = 0.4
    remove_diagonal: bool = True


class TFLossCosineSimilarity(TFLossBaseConfig, TFAxis):
    pass


class TFLossDice(TFLossBaseConfig, TFAxis):
    pass


class TFLossHuber(TFLossBaseConfig):
    delta: float = 1.0


class TFLossSparseCategoricalCrossentropy(TFLossBaseConfig, TFLogits, TFIgnoreClass):
    pass


class TFLossTversky(TFLossBaseConfig, TFAxis):
    alpha: float = 0.5
    beta: float = 0.5
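
# Illustrative sketch (an assumption, not shown in this module): loss config
# field names mirror the keras loss constructor kwargs, so presumably 'name'
# selects the loss class and the rest of the dump is splatted into it.
_dice_example = TFLossDice(name='Dice', axis=3)
assert _dice_example.model_dump(exclude={'name'}) == {
    'dtype': None, 'reduction': 'sum_over_batch_size', 'axis': 3
}
# e.g. keras.losses.Dice(**_dice_example.model_dump(exclude={'name'}))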


# METRICS-----------------------------------------------------------------------
class TFMetricAUC(TFMetricBaseConfig, TFLogits, TFNumThresh, TFThresh):
    curve: Annotated[str, pyd.Field(pattern='^(ROC|PR)$')] = 'ROC'
    label_weights: OptListFloat = None
    multi_label: bool = False
    num_labels: OptInt = None
    summation_method: Annotated[
        str, pyd.Field(pattern='^(interpolation|minoring|majoring)$')
    ] = 'interpolation'
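
# Illustrative sketch: the regex-constrained fields catch bad metric settings at
# validation time. _auc_example is a hypothetical instance for illustration.
_auc_example = TFMetricAUC(name='AUC', curve='PR', num_thresholds=100)
assert _auc_example.summation_method == 'interpolation'
# TFMetricAUC(name='AUC', curve='roc') would raise pydantic.ValidationError,
# because curve must match '^(ROC|PR)$' exactly (case-sensitive).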


class TFMetricAccuracy(TFMetricBaseConfig):
    pass


class TFMetricBinaryAccuracy(TFMetricBaseConfig):
    threshold: float = 0.5


class TFMetricBinaryCrossentropy(TFMetricBaseConfig, TFLogits):
    label_smoothing: int = 0


class TFMetricBinaryIoU(TFMetricBaseConfig):
    target_class_ids: list[int] = [0, 1]
    threshold: float = 0.5


class TFMetricCategoricalAccuracy(TFMetricBaseConfig):
    pass


class TFMetricCategoricalCrossentropy(TFMetricBaseConfig, TFAxis, TFLogits):
    label_smoothing: int = 0


class TFMetricCategoricalHinge(TFMetricBaseConfig):
    pass


class TFMetricConcordanceCorrelation(TFMetricBaseConfig, TFAxis):
    pass


class TFMetricCosineSimilarity(TFMetricBaseConfig, TFAxis):
    pass


class TFMetricF1Score(TFMetricBaseConfig):
    average: OptStr = None
    threshold: OptFloat = None


class TFMetricFBetaScore(TFMetricBaseConfig):
    average: OptStr = None
    beta: float = 1.0
    threshold: OptFloat = None


class TFMetricFalseNegatives(TFMetricBaseConfig, TFThresh):
    pass


class TFMetricFalsePositives(TFMetricBaseConfig, TFThresh):
    pass


class TFMetricHinge(TFMetricBaseConfig):
    pass


class TFMetricIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = True
    sparse_y_true: bool = True
    target_class_ids: list[int]


class TFMetricKLDivergence(TFMetricBaseConfig):
    pass


class TFMetricLogCoshError(TFMetricBaseConfig):
    pass


class TFMetricMean(TFMetricBaseConfig):
    pass


class TFMetricMeanAbsoluteError(TFMetricBaseConfig):
    pass


class TFMetricMeanAbsolutePercentageError(TFMetricBaseConfig):
    pass


class TFMetricMeanIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = True
    sparse_y_true: bool = True


class TFMetricMeanSquaredError(TFMetricBaseConfig):
    pass


class TFMetricMeanSquaredLogarithmicError(TFMetricBaseConfig):
    pass


class TFMetricMetric(TFMetricBaseConfig):
    pass


class TFMetricOneHotIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = False
    target_class_ids: list[int]


class TFMetricOneHotMeanIoU(TFMetricBaseConfig, TFAxis, TFIgnoreClass, TFNumClasses):
    sparse_y_pred: bool = False


class TFMetricPearsonCorrelation(TFMetricBaseConfig, TFAxis):
    pass


class TFMetricPoisson(TFMetricBaseConfig):
    pass


class TFMetricPrecision(TFMetricBaseConfig, TFClsId, TFThresh):
    top_k: OptInt = None


class TFMetricPrecisionAtRecall(TFMetricBaseConfig, TFClsId, TFNumThresh):
    recall: float


class TFMetricR2Score(TFMetricBaseConfig):
    class_aggregation: Optional[Annotated[
        str, pyd.Field(pattern='^(uniform_average|variance_weighted_average)$')
    ]] = 'uniform_average'
    num_regressors: Annotated[int, pyd.Field(ge=0)] = 0


class TFMetricRecall(TFMetricBaseConfig, TFClsId, TFThresh):
    top_k: OptInt = None


class TFMetricRecallAtPrecision(TFMetricBaseConfig, TFClsId, TFNumThresh):
    precision: float


class TFMetricRootMeanSquaredError(TFMetricBaseConfig):
    pass


class TFMetricSensitivityAtSpecificity(TFMetricBaseConfig, TFClsId, TFNumThresh):
    specificity: float


class TFMetricSparseCategoricalAccuracy(TFMetricBaseConfig):
    pass


class TFMetricSparseCategoricalCrossentropy(TFMetricBaseConfig, TFAxis, TFLogits):
    pass


class TFMetricSparseTopKCategoricalAccuracy(TFMetricBaseConfig):
    from_sorted_ids: bool = False
    k: int = 5


class TFMetricSpecificityAtSensitivity(TFMetricBaseConfig, TFClsId, TFNumThresh):
    sensitivity: float


class TFMetricSquaredHinge(TFMetricBaseConfig):
    pass


class TFMetricSum(TFMetricBaseConfig):
    pass


class TFMetricTopKCategoricalAccuracy(TFMetricBaseConfig):
    k: int = 5


class TFMetricTrueNegatives(TFMetricBaseConfig, TFThresh):
    pass


class TFMetricTruePositives(TFMetricBaseConfig, TFThresh):
    pass