Coverage for /home/ubuntu/flatiron/python/flatiron/torch/models/dummy.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

1import pydantic as pyd 

2import torch 

3import torch.nn as nn 

4 

5import flatiron.core.pipeline as ficp 

6# ------------------------------------------------------------------------------ 

7 

8 

class DummyModel(nn.Module):
    """
    Minimal convolutional model: a single 3x3 ``Conv2d`` followed by ``ReLU``.

    The convolution uses a 3x3 kernel with ``padding=1`` and the default
    stride, so the height and width of the input are preserved and only the
    channel count changes.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        dtype: torch.dtype = torch.float16,
    ) -> None:
        """
        Build the layer stack.

        Args:
            input_channels: Number of channels in the input tensor.
            output_channels: Number of channels produced by the convolution.
            dtype: Parameter dtype of the convolution. Defaults to
                ``torch.float16``, which was previously the only option;
                pass ``torch.float32`` where half-precision convolution is
                unavailable.
        """
        super().__init__()
        self.layer_stack = nn.Sequential(
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=(3, 3),
                dtype=dtype,
                padding=1,
            ),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the convolution and ReLU to ``x``.

        Args:
            x: Input tensor. Its dtype must match the model's parameter dtype
                and its channel dimension must equal ``input_channels``.

        Returns:
            The non-negative activation tensor produced by the layer stack.
        """
        return self.layer_stack(x)

22 

23 

def get_dummy_model(input_channels=3, output_channels=1):
    """
    Factory for DummyModel.

    Args:
        input_channels: Forwarded to DummyModel as ``input_channels``.
            Default: 3.
        output_channels: Forwarded to DummyModel as ``output_channels``.
            Default: 1.

    Returns:
        A newly constructed DummyModel instance.
    """
    model = DummyModel(
        input_channels=input_channels, output_channels=output_channels
    )
    return model

29 

30 

# Pydantic schema holding the two integer settings of the dummy model.
# Both fields are required (no defaults are declared) and their names match
# the keyword parameters of get_dummy_model.
class DummyConfig(pyd.BaseModel):
    # channel count of the model input
    input_channels: int
    # channel count of the model output
    output_channels: int

34 

35 

class DummyPipeline(ficp.PipelineBase):
    """
    Pipeline subclass that wires the dummy model into ``ficp.PipelineBase``.

    It only supplies the two hooks below; how the base class consumes them is
    defined in ``flatiron.core.pipeline`` and is not visible here.
    """

    def model_config(self):
        """
        Return the config class for this pipeline.

        Returns:
            The DummyConfig class itself, not an instance.
        """
        return DummyConfig

    def model_func(self):
        """
        Return the model factory for this pipeline.

        Returns:
            The get_dummy_model function itself, uncalled.
        """
        return get_dummy_model