Coverage for /home/ubuntu/flatiron/python/flatiron/core/config.py: 100%
70 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
1from typing import Optional, Union
2from typing_extensions import Annotated
4from flatiron.core.types import OptLabels, OptFloat, Getter
6import pydantic as pyd
8import flatiron.core.validators as vd
9# ------------------------------------------------------------------------------
class BaseConfig(pyd.BaseModel):
    '''
    Shared base class for configuration models in this module.

    Sets the pydantic model config to extra='forbid', which subclasses
    inherit.
    '''
    model_config = pyd.ConfigDict(extra='forbid')
class DatasetConfig(BaseConfig):
    '''
    Configuration for Dataset.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.dataset

    Attributes:
        source (str): Dataset directory or CSV filepath.
        ext_regex (str, optional): File extension pattern.
            Default: 'npy|exr|png|jpeg|jpg|tiff'.
        labels (object, optional): Label channels. Default: None.
        label_axis (int, optional): Label axis. Default: -1.
        test_size (float, optional): Test set size as a proportion.
            Must be greater than or equal to 0 when given. Default: 0.2.
        limit (int, optional): Limit data by number of samples.
            Must be greater than or equal to 0 when given. Default: None.
        reshape (bool, optional): Reshape concatenated data to incorporate
            frames as the first dimension: (FRAME, ...). Analogous to the
            first dimension being batch. Default: True.
        shuffle (bool, optional): Randomize data before splitting.
            Default: True.
        seed (int, optional): Shuffle seed number. Default: None.
    '''
    # data location and file selection
    source: str
    ext_regex: str = 'npy|exr|png|jpeg|jpg|tiff'

    # label handling
    labels: OptLabels = None
    label_axis: int = -1

    # splitting; both bounded fields reject negative values but accept None
    test_size: Optional[Annotated[float, pyd.Field(ge=0)]] = 0.2
    limit: Optional[Annotated[int, pyd.Field(ge=0)]] = None

    # loading behavior
    reshape: bool = True
    shuffle: bool = True
    seed: Optional[int] = None
class FrameworkConfig(pyd.BaseModel):
    '''
    Configuration for deep learning framework.

    Derives from pyd.BaseModel directly rather than BaseConfig, so the
    extra='forbid' setting of BaseConfig is not inherited here.

    Attributes:
        name (str, optional): Framework name. Checked with vd.is_engine after
            standard string validation. Default: 'tensorflow'.
        device (str, optional): Device name. Default: 'cpu'.
    '''
    name: Annotated[str, pyd.AfterValidator(vd.is_engine)] = 'tensorflow'
    device: str = 'cpu'
class OptimizerConfig(pyd.BaseModel):
    '''
    Configuration for optimizer.

    Derives from pyd.BaseModel directly rather than BaseConfig, so the
    extra='forbid' setting of BaseConfig is not inherited here.

    See: https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer

    Attributes:
        name (str, optional): Name of optimizer. Default: 'SGD'.
    '''
    name: str = 'SGD'
class LossConfig(pyd.BaseModel):
    '''
    Configuration for loss.

    Derives from pyd.BaseModel directly rather than BaseConfig, so the
    extra='forbid' setting of BaseConfig is not inherited here.

    Attributes:
        name (str, optional): Name of loss. Default: 'MeanSquaredError'.
    '''
    name: str = 'MeanSquaredError'
class CallbacksConfig(BaseConfig):
    '''
    Configuration for callbacks.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.tools
    See: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint

    Attributes:
        project (str): Name of project.
        root (str): Tensorboard parent directory. Required; no default is
            declared on this model.
        monitor (str, optional): Metric to monitor. Default: 'val_loss'.
        verbose (int, optional): Log callback actions. Default: 0.
        save_best_only (bool, optional): Save only best model.
            Default: False.
        save_weights_only (bool, optional): Only save model weights.
            Default: False.
        mode (str, optional): Overwrite best model via
            `mode(old metric, new metric)`. Options: [auto, min, max].
            Checked with vd.is_callback_mode. Default: 'auto'.
        save_freq (str or int, optional): Save after each epoch or N batches.
            Options: 'epoch' or int. Default: 'epoch'.
        initial_value_threshold (float, optional): Initial best value of
            metric. Default: None.
    '''
    # required fields
    project: str
    root: str

    # checkpoint options
    monitor: str = 'val_loss'
    verbose: int = 0
    save_best_only: bool = False
    save_weights_only: bool = False
    mode: Annotated[str, pyd.AfterValidator(vd.is_callback_mode)] = 'auto'
    save_freq: Union[str, int] = 'epoch'
    initial_value_threshold: OptFloat = None
class TrainConfig(BaseConfig):
    '''
    Configuration for calls to model train function.

    See: https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit

    Attributes:
        batch_size (int, optional): Number of samples per update.
            Default: 32.
        epochs (int, optional): Number of epochs to train model.
            Default: 30.
        verbose (str or int, optional): Verbosity of model logging.
            Options: 'auto', 0, 1, 2.
            0 is silent. 1 is progress bar. 2 is one line per epoch.
            Auto is usually 1. Default: 'auto'.
        validation_split (float, optional): Fraction of training data to use
            for validation. Default: 0.0.
        seed (int, optional): Seed value. Default: 42.
        shuffle (bool, optional): Shuffle training data per epoch.
            Default: True.
        initial_epoch (int, optional): Epoch at which to start training
            (useful for resuming a previous training run). Default: 1.
        validation_freq (int, optional): Number of training epochs before
            new validation. Default: 1.
    '''
    # core training loop parameters
    batch_size: int = 32
    epochs: int = 30
    verbose: Union[str, int] = 'auto'

    # validation and reproducibility
    validation_split: float = 0.0
    seed: int = 42
    shuffle: bool = True

    # scheduling
    initial_epoch: int = 1
    validation_freq: int = 1
class LoggerConfig(BaseConfig):
    '''
    Configuration for logger.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.logging

    Attributes:
        slack_channel (str, optional): Slack channel name. Default: None.
        slack_url (str, optional): Slack URL name. Default: None.
        slack_methods (list[str], optional): Pipeline methods to be logged to
            Slack. Each item is checked with vd.is_pipeline_method.
            Default: [load, compile, train].
        timezone (str, optional): Timezone. Default: UTC.
        level (str, optional): Log level. Default: warn.
    '''
    slack_channel: Optional[str] = None
    slack_url: Optional[str] = None
    slack_methods: list[str] = pyd.Field(default=['load', 'compile', 'train'])
    timezone: str = 'UTC'
    level: str = 'warn'

    @pyd.field_validator('slack_methods')
    @classmethod
    def _validate_slack_methods(cls, value):
        '''
        Check every slack method name with vd.is_pipeline_method.

        Args:
            value (list[str]): Candidate slack methods.

        Returns:
            list[str]: The given value, unchanged.
        '''
        # the validator's return value is ignored; it is called per item so
        # that any error it raises propagates to pydantic
        for item in value:
            vd.is_pipeline_method(item)
        return value
class PipelineConfig(BaseConfig):
    '''
    Configuration for PipelineBase classes.

    See: https://thenewflesh.github.io/flatiron/core.html#module-flatiron.core.pipeline

    Attributes:
        framework (dict): Deep learning framework config.
        dataset (dict): Dataset configuration.
        optimizer (dict): Optimizer configuration.
        loss (dict): Loss configuration.
        metrics (list[dict], optional): Metric dicts. Each must contain a
            name key. Default: [dict(name='Mean')].
        callbacks (dict): Callbacks configuration.
        logger (dict): Logger configuration.
        train (dict): Train configuration.
    '''
    framework: FrameworkConfig
    dataset: DatasetConfig
    optimizer: OptimizerConfig
    loss: LossConfig
    metrics: list[Getter] = [dict(name='Mean')]
    callbacks: CallbacksConfig
    logger: LoggerConfig
    train: TrainConfig

    @pyd.field_validator('metrics')
    @classmethod
    def _validate_metrics(cls, items):
        '''
        Ensure every metric item contains a name key.

        Args:
            items (list[dict]): Candidate metric items.

        Raises:
            ValueError: If any item lacks a name key.

        Returns:
            list[dict]: The given items, unchanged.
        '''
        for item in items:
            # membership test on the item itself checks its keys
            if 'name' not in item:
                msg = f'All dicts must contain name key. Given value: {item}.'
                raise ValueError(msg)
        return items