Coverage for /home/ubuntu/flatiron/python/flatiron/torch/models/dummy.py: 100%
20 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
1import pydantic as pyd
2import torch
3import torch.nn as nn
5import flatiron.core.pipeline as ficp
6# ------------------------------------------------------------------------------
class DummyModel(nn.Module):
    """
    Minimal model: a single 3x3 2D convolution followed by a ReLU.

    Args:
        input_channels (int): Number of channels of the input tensor.
        output_channels (int): Number of channels produced by the convolution.
        dtype (torch.dtype, optional): Dtype of the convolution's parameters.
            Default: torch.float16 (the value this model previously
            hard-coded). Inputs passed to ``forward`` must use a matching
            dtype, as required by ``nn.Conv2d``.
    """

    def __init__(self, input_channels, output_channels, dtype=torch.float16):
        super().__init__()
        self.layer_stack = nn.Sequential(
            # A 3x3 kernel with padding=1 and the default stride of 1 keeps
            # the spatial size (height, width) of the input unchanged.
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=(3, 3),
                dtype=dtype,
                padding=1,
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        """
        Apply the convolution and the ReLU to ``x``.

        Args:
            x (torch.Tensor): Input tensor with ``input_channels`` channels,
                in the layout ``nn.Conv2d`` expects (N, C, H, W) or (C, H, W).

        Returns:
            torch.Tensor: Non-negative output with ``output_channels`` channels
                and the same spatial size as ``x``.
        """
        return self.layer_stack(x)
def get_dummy_model(input_channels=3, output_channels=1):
    """
    Construct and return a DummyModel.

    Args:
        input_channels (int, optional): Input channel count. Default: 3.
        output_channels (int, optional): Output channel count. Default: 1.

    Returns:
        DummyModel: Freshly constructed model instance.
    """
    model = DummyModel(
        input_channels=input_channels,
        output_channels=output_channels,
    )
    return model
class DummyConfig(pyd.BaseModel):
    # Pydantic schema holding the keyword arguments of get_dummy_model.
    # NOTE(review): presumably used by PipelineBase to validate model params
    # before calling the model factory — confirm against ficp.PipelineBase.
    input_channels: int  # channel count of the model's input tensor
    output_channels: int  # channel count produced by the model's convolution
class DummyPipeline(ficp.PipelineBase):
    """
    Pipeline that wires DummyConfig and get_dummy_model into PipelineBase.
    """

    def model_config(self):
        """
        Return the pydantic class describing this pipeline's model config.

        Returns:
            type: The DummyConfig class (not an instance).
        """
        config_class = DummyConfig
        return config_class

    def model_func(self):
        """
        Return the factory callable that builds this pipeline's model.

        Returns:
            function: The get_dummy_model function (not called).
        """
        factory = get_dummy_model
        return factory