Coverage for /home/ubuntu/flatiron/python/flatiron/tf/models/dummy.py: 100%
18 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
1from tensorflow import keras # noqa F401
2from keras import layers as tfl
3from keras import models as tfmodels
4import pydantic as pyd
6import flatiron.core.pipeline as ficp
7# ------------------------------------------------------------------------------
def get_dummy_model(shape, activation='relu'):
    """Build a minimal Keras model: one input feeding one 1x1 convolution.

    Args:
        shape: Forwarded unchanged as the first argument of ``tfl.Input``.
            Presumably (height, width, channels) since the next layer is a
            Conv2D -- confirm against callers.
        activation: Activation passed to the Conv2D layer.
            Defaults to 'relu'.

    Returns:
        A ``tfmodels.Model`` whose single input is the layer named 'input'
        and whose single output is the Conv2D layer named 'output'.
    """
    entry = tfl.Input(shape, name='input')

    # Single filter with a (1, 1) kernel.
    conv = tfl.Conv2D(1, (1, 1), activation=activation, name='output')

    return tfmodels.Model(inputs=[entry], outputs=[conv(entry)])
class DummyConfig(pyd.BaseModel):
    """Pydantic schema for the dummy model's configuration.

    Field names, and the 'relu' default, match the parameters of
    ``get_dummy_model``.
    """

    # Required: list of integers describing the model input shape.
    shape: list[int]

    # Optional: activation name, 'relu' when omitted.
    activation: str = pyd.Field(default='relu')
class DummyPipeline(ficp.PipelineBase):
    """Pipeline subclass that pairs the dummy model with its config schema.

    NOTE(review): how ``PipelineBase`` consumes these two hooks is not
    visible here -- confirm against ``flatiron.core.pipeline``.
    """

    def model_config(self):
        """Return the ``DummyConfig`` class itself (not an instance)."""
        return DummyConfig

    def model_func(self):
        """Return the ``get_dummy_model`` function itself (not its result)."""
        return get_dummy_model