Coverage for /home/ubuntu/flatiron/python/flatiron/torch/models/unet.py: 40%
101 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 21:55 +0000
1from typing import Callable # noqa F401
2from typing_extensions import Annotated
4import pydantic as pyd
6import torch
7import torch.nn as nn
9import flatiron.core.validators as vd
10import flatiron.core.pipeline as ficp
11# ------------------------------------------------------------------------------
class Conv2DBlock(nn.Module):
    '''
    Two stacked conv -> ReLU -> batch-norm stages.

    Both convolutions use a 3x3 kernel, stride 1 and padding 1 with
    reflect padding, so the spatial size of the input is preserved and only
    the channel count changes (``in_channels`` -> ``filters``).

    Args:
        in_channels (int): Number of channels of the input tensor.
        filters (int, optional): Output channels of both convolutions.
            Default: 16.
        dtype (torch.dtype, optional): Dtype of the layer parameters.
            Default: torch.float16.
    '''
    def __init__(self, in_channels, filters=16, dtype=torch.float16):
        super().__init__()
        kwargs = dict(
            out_channels=filters, kernel_size=(3, 3),
            stride=(1, 1), padding=1, padding_mode='reflect', dtype=dtype
        )
        self.conv_1 = nn.Conv2d(in_channels=in_channels, **kwargs)
        # Both stages use ReLU; act_1 must not be reassigned afterwards
        # (a stray `self.act_1 = nn.Sigmoid()` previously replaced it).
        self.act_1 = nn.ReLU()
        self.batch_1 = nn.BatchNorm2d(filters, dtype=dtype)
        self.conv_2 = nn.Conv2d(in_channels=filters, **kwargs)
        self.act_2 = nn.ReLU()
        self.batch_2 = nn.BatchNorm2d(filters, dtype=dtype)

    def forward(self, x):
        '''
        Apply both conv -> ReLU -> batch-norm stages.

        Args:
            x (torch.Tensor): Input tensor; assumed (N, C, H, W) with
                C == in_channels — TODO confirm callers never pass
                unbatched input.

        Returns:
            torch.Tensor: Tensor with ``filters`` channels and the same
                spatial size as ``x``.
        '''
        x = self.conv_1(x)
        x = self.act_1(x)
        x = self.batch_1(x)
        x = self.conv_2(x)
        x = self.act_2(x)
        x = self.batch_2(x)
        return x
class AtttentionGate2DBlock(nn.Module):
    '''
    Additive attention gate applied to a skip connection.

    The skip connection and the query are each projected to ``filters``
    channels, summed, passed through ReLU, reduced to a single channel and
    squashed with a sigmoid. The resulting 1-channel gate scales the
    projected skip tensor, which is then concatenated with the projected
    query along the channel axis.

    Args:
        in_channels (int): Channels of both ``skip_connection`` and
            ``query`` (both projections use this as their input width).
        filters (int, optional): Channels of the intermediate projections.
            The output has ``2 * filters`` channels. Default: 16.
        dtype (torch.dtype, optional): Dtype of the layer parameters.
            Default: torch.float16.
    '''
    def __init__(self, in_channels, filters=16, dtype=torch.float16):
        super().__init__()
        kwargs = dict(
            kernel_size=(3, 3),
            stride=(1, 1), padding=1, padding_mode='reflect', dtype=dtype
        )
        self.conv_0 = nn.Conv2d(in_channels=in_channels, out_channels=filters, **kwargs)
        self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=filters, **kwargs)
        self.act_1 = nn.ReLU()
        self.conv_2 = nn.Conv2d(in_channels=filters, out_channels=1, **kwargs)
        # Sigmoid belongs to the second activation; forward() calls act_2,
        # which previously did not exist (the sigmoid overwrote act_1).
        self.act_2 = nn.Sigmoid()

    def forward(self, skip_connection, query):
        '''
        Gate ``skip_connection`` with ``query`` and concatenate the results.

        Args:
            skip_connection (torch.Tensor): Encoder tensor; assumed
                (N, in_channels, H, W) — TODO confirm.
            query (torch.Tensor): Gating tensor with the same shape as
                ``skip_connection`` (required by the element-wise add).

        Returns:
            torch.Tensor: Gated skip and projected query concatenated on
                dim 1, i.e. ``2 * filters`` channels.
        '''
        skip = self.conv_0(skip_connection)
        query = self.conv_1(query)

        gate = torch.add(skip, query)
        gate = self.act_1(gate)
        gate = self.conv_2(gate)
        gate = self.act_2(gate)
        # NOTE(review): this scales the projected skip (conv_0 output), not
        # the raw skip_connection — confirm that is the intended design.
        gate = torch.multiply(skip, gate)

        # Concatenate along channels (dim=1), matching UNet.forward; the
        # default dim=0 would stack along the batch dimension instead.
        x = torch.concatenate([gate, query], dim=1)
        return x
class UNet(nn.Module):
    '''
    Two-level UNet assembled from Conv2DBlock stages.

    Encoder: Conv2DBlock(16) -> 2x2 max-pool -> Conv2DBlock(32) -> 2x2
    max-pool. Bottleneck: Conv2DBlock(64). Decoder: two stride-2 transposed
    convolutions, each followed by channel-wise concatenation with the
    matching encoder output and a Conv2DBlock. Because of the two 2x2 pools,
    input height and width should be divisible by 4 so the upsampled tensors
    line up with the skip tensors.

    Args:
        in_channels (int, optional): Input channels. Default: 3.
        out_channels (int, optional): Output channels. Default: 1.
        attention (bool, optional): Stored as ``self._attention``; not used
            by ``forward`` in this module. Default: False.
        dtype (torch.dtype, optional): Dtype of the layer parameters.
            Default: torch.float16.
    '''
    def __init__(self, in_channels=3, out_channels=1, attention=False, dtype=torch.float16):
        super().__init__()
        self._attention = attention

        def conv_block(n_in, n_out):
            return Conv2DBlock(in_channels=n_in, filters=n_out, dtype=dtype)

        def max_pool():
            return nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        def up_conv(n_in, n_out):
            return nn.ConvTranspose2d(
                in_channels=n_in, out_channels=n_out,
                kernel_size=(2, 2), stride=(2, 2), dtype=dtype,
            )

        # Encoder path.
        self.encode_block_00 = conv_block(in_channels, 16)
        self.downsample_00 = max_pool()
        self.encode_block_01 = conv_block(16, 32)
        self.downsample_01 = max_pool()

        # Bottleneck.
        self.middle_block = conv_block(32, 64)

        # Decoder path; each decode block sees upsampled + skip channels.
        self.upsample_00 = up_conv(64, 32)
        self.decode_block_00 = conv_block(64, 32)
        self.upsample_01 = up_conv(32, 16)
        self.decode_block_01 = conv_block(32, out_channels)

    def forward(self, x):
        '''
        Run the encoder, bottleneck and decoder.

        Args:
            x (torch.Tensor): Input tensor; assumed (N, in_channels, H, W).

        Returns:
            torch.Tensor: Tensor with ``out_channels`` channels at the input
                spatial size.
        '''
        skip_hi = self.encode_block_00(x)
        skip_lo = self.encode_block_01(self.downsample_00(skip_hi))
        out = self.middle_block(self.downsample_01(skip_lo))

        merged = torch.cat([self.upsample_00(out), skip_lo], dim=1)
        out = self.decode_block_00(merged)
        merged = torch.cat([self.upsample_01(out), skip_hi], dim=1)
        return self.decode_block_01(merged)
def get_unet_model(in_channels, out_channels=3, dtype='float16'):
    '''
    Construct a UNet model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int, optional): Number of output channels. Default: 3.
        dtype (str or torch.dtype, optional): Parameter dtype, given either
            as a torch.dtype or as the name of one (e.g. 'float16',
            'float32'). Default: 'float16'.

    Raises:
        ValueError: If ``dtype`` is a string that does not name a
            torch.dtype.

    Returns:
        UNet: The constructed model.
    '''
    if isinstance(dtype, str):
        resolved = getattr(torch, dtype, None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f'Invalid torch dtype: {dtype!r}.')
        dtype = resolved

    # Keyword arguments are required: UNet's third positional parameter is
    # `attention`, so a positional dtype used to land there and be ignored.
    return UNet(in_channels=in_channels, out_channels=out_channels, dtype=dtype)
106# CONFIG------------------------------------------------------------------------
class UNetConfig(pyd.BaseModel):
    '''
    Configuration for UNet model.

    Attributes:
        input_width (int): Input width. Must be >= 1.
        input_height (int): Input height. Must be >= 1.
        input_channels (int): Input channels. Must be >= 1.
        classes (int, optional): Number of output classes. Default: 1.
        filters (int, optional): Number of filters for the initial conv 2d
            block. Default: 16.
        layers (int, optional): Total number of layers; must be >= 3 and
            pass ``vd.is_odd``. Default: 9.
        activation (str, optional): Activation function to be used.
            Default: 'relu'.
        batch_norm (bool, optional): Use batch normalization.
            Default: True.
        output_activation (str, optional): Output activation function.
            Default: 'sigmoid'.
        kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'.
        attention_gates (bool, optional): Use attention gates.
            Default: False.
        attention_activation_1 (str, optional): First activation.
            Default: 'relu'.
        attention_activation_2 (str, optional): Second activation.
            Default: 'sigmoid'.
        attention_kernel_size (int, optional): Kernel size. Default: 1.
        attention_strides (int, optional): Strides. Default: 1.
        attention_padding (str, optional): Padding; must pass
            ``vd.is_padding``. Default: 'same'.
        attention_kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'.
        dtype (str, optional): Model dtype name. Default: 'float16'.
        data_format (str, optional): Data format.
            Default: 'channels_last'.

    NOTE(review): ``get_unet_model`` in this module only accepts
    ``in_channels``, ``out_channels`` and ``dtype``; confirm how the
    pipeline maps these config fields onto it.
    '''
    input_width: Annotated[int, pyd.Field(ge=1)]
    input_height: Annotated[int, pyd.Field(ge=1)]
    input_channels: Annotated[int, pyd.Field(ge=1)]
    classes: Annotated[int, pyd.Field(ge=1)] = 1
    filters: Annotated[int, pyd.Field(ge=1)] = 16
    layers: Annotated[int, pyd.Field(ge=3), pyd.AfterValidator(vd.is_odd)] = 9
    activation: str = 'relu'
    batch_norm: bool = True
    output_activation: str = 'sigmoid'
    kernel_initializer: str = 'he_normal'
    attention_gates: bool = False
    attention_activation_1: str = 'relu'
    attention_activation_2: str = 'sigmoid'
    attention_kernel_size: Annotated[int, pyd.Field(ge=1)] = 1
    attention_strides: Annotated[int, pyd.Field(ge=1)] = 1
    attention_padding: Annotated[str, pyd.AfterValidator(vd.is_padding)] = 'same'
    attention_kernel_initializer: str = 'he_normal'
    dtype: str = 'float16'
    data_format: str = 'channels_last'
159# PIPELINE----------------------------------------------------------------------
class UNetPipeline(ficp.PipelineBase):
    '''
    Pipeline that pairs the UNetConfig schema with the get_unet_model
    factory.
    '''
    def model_config(self):
        # type: () -> type[UNetConfig]
        '''
        Returns:
            type[UNetConfig]: Config class describing the model options.
        '''
        config_class = UNetConfig
        return config_class

    def model_func(self):
        # type: () -> Callable[..., nn.Module]
        '''
        Returns:
            Callable[..., nn.Module]: Factory that builds the model.
        '''
        factory = get_unet_model
        return factory