Coverage for /home/ubuntu/flatiron/python/flatiron/tf/tools.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

1from typing import Any, Optional # noqa F401 

2from flatiron.core.dataset import Dataset # noqa F401 

3from flatiron.core.types import Compiled, Filepath, Getter # noqa: F401 

4 

5from copy import deepcopy 

6import math 

7 

8from tensorflow import keras # noqa F401 

9from keras import callbacks as tfcallbacks 

10import tensorflow as tf 

11 

12import flatiron.core.tools as fict 

13import flatiron.tf.loss as fi_tfloss 

14import flatiron.tf.metric as fi_tfmetric 

15import flatiron.tf.optimizer as fi_tfoptim 

16 

# Mapping of callback name (e.g. 'tensorboard', 'checkpoint') to a Keras
# callback instance; this is the return type of get_callbacks.
Callbacks = dict[str, tfcallbacks.TensorBoard | tfcallbacks.ModelCheckpoint]
# ------------------------------------------------------------------------------

19 

20 

def get(config, module, fallback_module):
    # type: (Getter, str, str) -> Any
    '''
    Given a config and set of modules return an instance or function.

    Resolution order:

        1. A function named `config['name']` defined in `module`.
        2. The fallback module's own `get` factory.
        3. Direct instantiation of the class from the fallback module.

    Args:
        config (dict): Instance config.
        module (str): Always __name__.
        fallback_module (str): Fallback module, either a tf or torch module.

    Raises:
        EnforceError: If config is not a dict with a name key.

    Returns:
        object: Instance or function.
    '''
    fict.enforce_getter(config)
    # --------------------------------------------------------------------------

    # copy so popping 'name' never mutates the caller's config
    kwargs = deepcopy(config)
    name = kwargs.pop('name')

    # 1. prefer a function defined in the given module
    try:
        return fict.get_module_function(name, module)
    except NotImplementedError:
        pass

    # 2. otherwise ask the fallback module's get factory
    try:
        fallback = fict.get_module(fallback_module)
        return fallback.get(dict(class_name=name, config=kwargs))
    except ValueError:
        pass

    # 3. finally, instantiate the class directly with remaining kwargs
    return fict.get_module_class(name, fallback_module)(**kwargs)

54 

55 

def get_callbacks(log_directory, checkpoint_pattern, checkpoint_params=None):
    # type: (Filepath, str, Optional[dict]) -> Callbacks
    '''
    Create a list of callbacks for TensorFlow model.

    Args:
        log_directory (str or Path): Tensorboard project log directory.
        checkpoint_pattern (str): Filepath pattern for checkpoint callback.
        checkpoint_params (dict, optional): Params to be passed to checkpoint
            callback. Default: None.

    Raises:
        EnforceError: If log directory does not exist.
        EnforceError: If checkpoint pattern does not contain '{epoch}'.

    Returns:
        dict: dict with Tensorboard and ModelCheckpoint callbacks.
    '''
    # avoid a shared mutable default argument; None means "no extra params"
    checkpoint_params = checkpoint_params or {}
    fict.enforce_callbacks(log_directory, checkpoint_pattern)
    return dict(
        tensorboard=tfcallbacks.TensorBoard(
            log_dir=log_directory, histogram_freq=1, update_freq=1
        ),
        checkpoint=tfcallbacks.ModelCheckpoint(
            checkpoint_pattern, **checkpoint_params
        ),
    )

81 

82 

def pre_build(device):
    # type: (str) -> None
    '''
    Sets hardware device.

    Args:
        device (str): Hardware device.
    '''
    # only the 'cpu' device requires action: hide all GPUs from TensorFlow
    if device != 'cpu':
        return
    tf.config.set_visible_devices([], 'GPU')

93 

94 

def compile(framework, model, optimizer, loss, metrics):
    # type: (Getter, Any, Getter, Getter, list[Getter]) -> Getter
    '''
    Call `model.compile` on given model with kwargs.

    Args:
        framework (dict): Framework dict.
        model (Any): Model to be compiled.
        optimizer (dict): Optimizer settings.
        loss (dict): Loss to be compiled.
        metrics (list[dict]): Metrics function to be compiled.

    Returns:
        dict: Dict of compiled objects.
    '''
    # copy before popping so the caller's framework dict is not mutated
    # (consistent with get, which deepcopies its config before popping)
    kwargs = deepcopy(framework)
    kwargs.pop('name')
    kwargs.pop('device')
    model.compile(
        optimizer=fi_tfoptim.get(optimizer),
        loss=fi_tfloss.get(loss),
        metrics=[fi_tfmetric.get(m) for m in metrics],
        **kwargs,
    )
    return dict(model=model)

119 

120 

def train(
    compiled,       # type: Compiled
    callbacks,      # type: Callbacks
    train_data,     # type: Dataset
    test_data,      # type: Optional[Dataset]
    params,         # type: dict
):
    # type: (...) -> None
    '''
    Train TensorFlow model.

    Args:
        compiled (dict): Compiled objects.
        callbacks (dict): Dict of callbacks.
        train_data (Dataset): Training dataset.
        test_data (Dataset): Test dataset.
        params (dict): Training params.
    '''
    # batch_size is required; fetch it first so a missing key fails early
    batch_size = params['batch_size']
    model = compiled['model']
    x_train, y_train = train_data.xy_split()

    # gradient steps needed to cover the training set once per epoch
    steps_per_epoch = math.ceil(x_train.shape[0] / batch_size)

    # validation data is optional
    validation = test_data.xy_split() if test_data is not None else None

    model.fit(
        x=x_train,
        y=y_train,
        callbacks=list(callbacks.values()),
        validation_data=validation,
        steps_per_epoch=steps_per_epoch,
        batch_size=params.get('batch_size', None),
        epochs=params.get('epochs', 1),
        verbose=params.get('verbose', 'auto'),
        validation_split=params.get('validation_split', 0.0),
        shuffle=params.get('shuffle', True),
        initial_epoch=params.get('initial_epoch', 0),
        validation_freq=params.get('validation_freq', 1),
    )