Coverage for /home/ubuntu/flatiron/python/flatiron/core/config.py: 100%

70 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

1from typing import Optional, Union 

2from typing_extensions import Annotated 

3 

4from flatiron.core.types import OptLabels, OptFloat, Getter 

5 

6import pydantic as pyd 

7 

8import flatiron.core.validators as vd 

9# ------------------------------------------------------------------------------ 

10 

11 

class BaseConfig(pyd.BaseModel):
    '''
    Common base class for flatiron configuration models.

    Sets ``extra='forbid'`` so that any key not declared as a field on a
    subclass raises a validation error instead of being silently accepted.
    '''
    model_config = pyd.ConfigDict(extra='forbid')

14 

15 

class DatasetConfig(BaseConfig):
    '''
    Configuration for Dataset.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.dataset

    Attributes:
        source (str): Dataset directory or CSV filepath.
        ext_regex (str, optional): File extension pattern.
            Default: 'npy|exr|png|jpeg|jpg|tiff'.
        labels (object, optional): Label channels. Default: None.
        label_axis (int, optional): Label axis. Default: -1.
        test_size (float, optional): Test set size as a proportion.
            Must be >= 0 when given. Default: 0.2.
        limit (int, optional): Limit data by number of samples.
            Must be >= 0 when given. Default: None.
        reshape (bool, optional): Reshape concatenated data to incorporate
            frames as the first dimension: (FRAME, ...). Analogous to the
            first dimension being batch. Default: True.
        shuffle (bool, optional): Randomize data before splitting.
            Default: True.
        seed (int, optional): Shuffle seed number. Default: None.
    '''
    source: str
    ext_regex: str = 'npy|exr|png|jpeg|jpg|tiff'
    labels: OptLabels = None
    label_axis: int = -1
    # None is allowed; a numeric value is constrained to be non-negative.
    test_size: Optional[Annotated[float, pyd.Field(ge=0)]] = 0.2
    limit: Optional[Annotated[int, pyd.Field(ge=0)]] = None
    reshape: bool = True
    shuffle: bool = True
    seed: Optional[int] = None

48 

49 

class FrameworkConfig(pyd.BaseModel):
    '''
    Configuration for the deep learning framework.

    Attributes:
        name (str, optional): Framework/engine name. After type validation
            it is checked with ``vd.is_engine``. Default: 'tensorflow'.
        device (str, optional): Device identifier. Default: 'cpu'.
    '''
    # NOTE(review): this derives from pyd.BaseModel rather than BaseConfig,
    # so keys other than the declared fields are not rejected here — confirm
    # whether that is intentional.
    name: Annotated[str, pyd.AfterValidator(vd.is_engine)] = 'tensorflow'
    device: str = 'cpu'

53 

54 

class OptimizerConfig(pyd.BaseModel):
    '''
    Optimizer settings.

    See: https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer

    Attributes:
        name (str, optional): Optimizer name. Default: 'SGD'.
    '''
    # NOTE(review): derived from pyd.BaseModel, not BaseConfig, so keys other
    # than ``name`` are not rejected by this model — confirm this is intended.
    name: str = 'SGD'

65 

66 

class LossConfig(pyd.BaseModel):
    '''
    Loss function settings.

    Attributes:
        name (str, optional): Loss name. Default: 'MeanSquaredError'.
    '''
    # NOTE(review): derived from pyd.BaseModel, not BaseConfig, so keys other
    # than ``name`` are not rejected by this model — confirm this is intended.
    name: str = 'MeanSquaredError'

75 

76 

class CallbacksConfig(BaseConfig):
    '''
    Configuration for callbacks.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.tools
    See: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint

    Attributes:
        project (str): Name of project.
        root (str): Tensorboard parent directory. Required; there is no
            default.
        monitor (str, optional): Metric to monitor. Default: 'val_loss'.
        verbose (int, optional): Log callback actions. Default: 0.
        save_best_only (bool, optional): Save only best model. Default: False.
        save_weights_only (bool, optional): Only save model weights.
            Default: False.
        mode (str, optional): Overwrite best model via
            `mode(old metric, new metric)`. Options: [auto, min, max].
            Checked with ``vd.is_callback_mode``. Default: 'auto'.
        save_freq (union, optional): Save after each epoch or N batches.
            Options: 'epoch' or int. Default: 'epoch'.
        initial_value_threshold (float, optional): Initial best value of metric.
            Default: None.
    '''
    project: str
    root: str
    monitor: str = 'val_loss'
    verbose: int = 0
    save_best_only: bool = False
    save_weights_only: bool = False
    mode: Annotated[str, pyd.AfterValidator(vd.is_callback_mode)] = 'auto'
    save_freq: Union[str, int] = 'epoch'
    initial_value_threshold: OptFloat = None

109 

110 

class TrainConfig(BaseConfig):
    '''
    Arguments forwarded to the model's train (fit) call.

    See: https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit

    Attributes:
        batch_size (int, optional): Samples per gradient update. Default: 32.
        epochs (int, optional): Number of training epochs. Default: 30.
        verbose (str or int, optional): Logging verbosity. One of 'auto', 0,
            1 or 2: 0 is silent, 1 shows a progress bar, 2 prints one line
            per epoch, and 'auto' usually means 1. Default: 'auto'.
        validation_split (float, optional): Fraction of the training data
            held out for validation. Default: 0.
        seed (int, optional): Seed value. Default: 42.
        shuffle (bool, optional): Shuffle the training data every epoch.
            Default: True.
        initial_epoch (int, optional): Epoch to start training from, which
            is useful when resuming an earlier run. Default: 1.
        validation_freq (int, optional): Training epochs to run between
            validation passes. Default: 1.
    '''
    # -- batching / duration --
    batch_size: int = 32
    epochs: int = 30
    # -- logging --
    verbose: Union[str, int] = 'auto'
    # -- data handling --
    validation_split: float = 0.0
    seed: int = 42
    shuffle: bool = True
    # -- scheduling --
    initial_epoch: int = 1
    validation_freq: int = 1

142 

143 

class LoggerConfig(BaseConfig):
    '''
    Configuration for logger.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.logging

    Attributes:
        slack_channel (str, optional): Slack channel name. Default: None.
        slack_url (str, optional): Slack URL name. Default: None.
        slack_methods (list[str], optional): Pipeline methods to be logged to
            Slack. Each supplied item is checked with
            ``vd.is_pipeline_method``. Default: [load, compile, train].
        timezone (str, optional): Timezone. Default: UTC.
        level (str, optional): Log level. Default: warn.
    '''
    slack_channel: Optional[str] = None
    slack_url: Optional[str] = None
    slack_methods: list[str] = pyd.Field(default=['load', 'compile', 'train'])
    timezone: str = 'UTC'
    level: str = 'warn'

    @pyd.field_validator('slack_methods')
    @classmethod
    def _validate_slack_methods(cls, value):
        '''
        Run ``vd.is_pipeline_method`` on every item of slack_methods.

        Args:
            value (list[str]): Type-validated slack_methods value.

        Returns:
            list[str]: The unchanged value.
        '''
        # The return value of the validator is discarded; presumably it
        # raises on an invalid method name — confirm in validators module.
        for item in value:
            vd.is_pipeline_method(item)
        return value

169 

170 

class PipelineConfig(BaseConfig):
    '''
    Configuration for PipelineBase classes.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.pipeline

    Attributes:
        framework (dict): Deep learning framework config.
        dataset (dict): Dataset configuration.
        optimizer (dict): Optimizer configuration.
        loss (dict): Loss configuration.
        metrics (list[dict], optional): Metric dicts. Every dict must contain
            a 'name' key. Default=[dict(name='Mean')].
        callbacks (dict): Callbacks configuration.
        logger (dict): Logger configuration.
        train (dict): Train configuration.
    '''
    framework: FrameworkConfig
    dataset: DatasetConfig
    optimizer: OptimizerConfig
    loss: LossConfig
    metrics: list[Getter] = [dict(name='Mean')]
    callbacks: CallbacksConfig
    logger: LoggerConfig
    train: TrainConfig

    @pyd.field_validator('metrics')
    @classmethod
    def _validate_metrics(cls, items):
        '''
        Ensure every metric dict declares a name.

        Args:
            items (list[dict]): Type-validated metrics value.

        Raises:
            ValueError: If any item lacks a 'name' key.

        Returns:
            list[dict]: The unchanged items.
        '''
        for item in items:
            if 'name' not in item.keys():
                msg = f'All dicts must contain name key. Given value: {item}.'
                raise ValueError(msg)
        return items