Coverage for /home/ubuntu/flatiron/python/flatiron/torch/models/unet.py: 40% (101 statements)

from typing import Callable  # noqa: F401
from typing_extensions import Annotated

import pydantic as pyd

import torch
import torch.nn as nn

import flatiron.core.validators as vd
import flatiron.core.pipeline as ficp
# ------------------------------------------------------------------------------


class Conv2DBlock(nn.Module):
    '''
    Two 3x3 convolutions, each followed by a ReLU activation and batch
    normalization. Spatial dimensions are preserved via reflect padding.
    '''
    def __init__(self, in_channels, filters=16, dtype=torch.float16):
        super().__init__()
        kwargs = dict(
            out_channels=filters, kernel_size=(3, 3),
            stride=(1, 1), padding=1, padding_mode='reflect', dtype=dtype
        )
        self.conv_1 = nn.Conv2d(in_channels=in_channels, **kwargs)
        self.act_1 = nn.ReLU()
        self.batch_1 = nn.BatchNorm2d(filters, dtype=dtype)
        self.conv_2 = nn.Conv2d(in_channels=filters, **kwargs)
        self.act_2 = nn.ReLU()
        self.batch_2 = nn.BatchNorm2d(filters, dtype=dtype)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.act_1(x)
        x = self.batch_1(x)
        x = self.conv_2(x)
        x = self.act_2(x)
        x = self.batch_2(x)
        return x
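
# Illustrative sketch (not part of the original module): a quick shape check
# for Conv2DBlock. The 1x3x64x64 input and float32 dtype are assumptions made
# only for the example.
#
#     block = Conv2DBlock(in_channels=3, filters=16, dtype=torch.float32)
#     y = block(torch.zeros(1, 3, 64, 64))
#     assert y.shape == (1, 16, 64, 64)  # channels -> filters, spatial size kept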


class AttentionGate2DBlock(nn.Module):
    '''
    Additive attention gate. A single-channel sigmoid mask is computed from
    the skip connection and the query, multiplied into the convolved skip
    connection, and the result is concatenated with the convolved query along
    the channel axis.
    '''
    def __init__(self, in_channels, filters=16, dtype=torch.float16):
        super().__init__()
        kwargs = dict(
            kernel_size=(3, 3),
            stride=(1, 1), padding=1, padding_mode='reflect', dtype=dtype
        )
        self.conv_0 = nn.Conv2d(in_channels=in_channels, out_channels=filters, **kwargs)
        self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=filters, **kwargs)
        self.act_1 = nn.ReLU()
        self.conv_2 = nn.Conv2d(in_channels=filters, out_channels=1, **kwargs)
        self.act_2 = nn.Sigmoid()

    def forward(self, skip_connection, query):
        skip = self.conv_0(skip_connection)
        query = self.conv_1(query)

        gate = torch.add(skip, query)
        gate = self.act_1(gate)
        gate = self.conv_2(gate)
        gate = self.act_2(gate)
        gate = torch.multiply(skip, gate)

        x = torch.cat([gate, query], dim=1)
        return x
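
# Illustrative sketch (not part of the original module): the gate expects the
# skip connection and query to have matching shapes and returns their gated
# concatenation with 2 * filters channels. The shapes below are assumptions
# made only for the example.
#
#     gate = AttentionGate2DBlock(in_channels=32, filters=16, dtype=torch.float32)
#     skip = torch.zeros(1, 32, 64, 64)
#     query = torch.zeros(1, 32, 64, 64)
#     assert gate(skip, query).shape == (1, 32, 64, 64)  # 16 gated + 16 query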


class UNet(nn.Module):
    '''
    Two-level UNet: two encoder blocks with 2x2 max pooling, a middle block,
    and two decoder blocks with transposed-convolution upsampling and skip
    connections.
    '''
    def __init__(self, in_channels=3, out_channels=1, attention=False, dtype=torch.float16):
        super().__init__()
        self._attention = attention  # stored but not currently used in forward
        kwargs = dict(dtype=dtype)
        pool_kwargs = dict(kernel_size=(2, 2), stride=(2, 2))
        trans_kwargs = dict(kernel_size=(2, 2), stride=(2, 2), dtype=dtype)

        self.encode_block_00 = Conv2DBlock(in_channels=in_channels, filters=16, **kwargs)
        self.downsample_00 = nn.MaxPool2d(**pool_kwargs)
        self.encode_block_01 = Conv2DBlock(in_channels=16, filters=32, **kwargs)
        self.downsample_01 = nn.MaxPool2d(**pool_kwargs)

        self.middle_block = Conv2DBlock(in_channels=32, filters=64, **kwargs)

        self.upsample_00 = nn.ConvTranspose2d(in_channels=64, out_channels=32, **trans_kwargs)
        self.decode_block_00 = Conv2DBlock(in_channels=64, filters=32, **kwargs)
        self.upsample_01 = nn.ConvTranspose2d(in_channels=32, out_channels=16, **trans_kwargs)
        self.decode_block_01 = Conv2DBlock(in_channels=32, filters=out_channels, **kwargs)

    def forward(self, x):
        x0 = self.encode_block_00(x)
        x = self.downsample_00(x0)
        x1 = self.encode_block_01(x)
        x = self.downsample_01(x1)
        x = self.middle_block(x)

        x = self.upsample_00(x)
        x = torch.cat([x, x1], dim=1)
        x = self.decode_block_00(x)
        x = self.upsample_01(x)
        x = torch.cat([x, x0], dim=1)
        x = self.decode_block_01(x)
        return x
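
# Illustrative shape trace for an assumed 1x3x64x64 input (any height and width
# divisible by 4 works, since there are two 2x2 max pools):
#
#     encode_block_00 -> (1, 16, 64, 64)    downsample_00 -> (1, 16, 32, 32)
#     encode_block_01 -> (1, 32, 32, 32)    downsample_01 -> (1, 32, 16, 16)
#     middle_block    -> (1, 64, 16, 16)
#     upsample_00     -> (1, 32, 32, 32)    cat with x1   -> (1, 64, 32, 32)
#     decode_block_00 -> (1, 32, 32, 32)
#     upsample_01     -> (1, 16, 64, 64)    cat with x0   -> (1, 32, 64, 64)
#     decode_block_01 -> (1, out_channels, 64, 64)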


def get_unet_model(in_channels, out_channels=3, dtype='float16'):
    return UNet(
        in_channels=in_channels, out_channels=out_channels, dtype=getattr(torch, dtype)
    )
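
# Illustrative sketch (not part of the original module): building a model via
# the factory and running a forward pass. The argument values are assumptions
# made only for the example.
#
#     model = get_unet_model(in_channels=3, out_channels=1, dtype='float32')
#     y = model(torch.zeros(1, 3, 64, 64))
#     assert y.shape == (1, 1, 64, 64)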


# CONFIG------------------------------------------------------------------------
class UNetConfig(pyd.BaseModel):
    '''
    Configuration for UNet model.

    Attributes:
        input_width (int): Input width.
        input_height (int): Input height.
        input_channels (int): Input channels.
        classes (int, optional): Number of output classes. Default: 1.
        filters (int, optional): Number of filters for initial conv 2d block.
            Default: 16.
        layers (int, optional): Total number of layers. Default: 9.
        activation (str, optional): Activation function to be used.
            Default: relu.
        batch_norm (bool, optional): Use batch normalization.
            Default: True.
        output_activation (str, optional): Output activation function.
            Default: sigmoid.
        kernel_initializer (str, optional): Kernel initializer.
            Default: he_normal.
        attention_gates (bool, optional): Use attention gates.
            Default: False.
        attention_activation_1 (str, optional): First activation.
            Default: 'relu'.
        attention_activation_2 (str, optional): Second activation.
            Default: 'sigmoid'.
        attention_kernel_size (int, optional): Kernel size. Default: 1.
        attention_strides (int, optional): Strides. Default: 1.
        attention_padding (str, optional): Padding. Default: 'same'.
        attention_kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'.
        dtype (str, optional): Model dtype. Default: 'float16'.
        data_format (str, optional): Data format. Default: 'channels_last'.
    '''
    input_width: Annotated[int, pyd.Field(ge=1)]
    input_height: Annotated[int, pyd.Field(ge=1)]
    input_channels: Annotated[int, pyd.Field(ge=1)]
    classes: Annotated[int, pyd.Field(ge=1)] = 1
    filters: Annotated[int, pyd.Field(ge=1)] = 16
    layers: Annotated[int, pyd.Field(ge=3), pyd.AfterValidator(vd.is_odd)] = 9
    activation: str = 'relu'
    batch_norm: bool = True
    output_activation: str = 'sigmoid'
    kernel_initializer: str = 'he_normal'
    attention_gates: bool = False
    attention_activation_1: str = 'relu'
    attention_activation_2: str = 'sigmoid'
    attention_kernel_size: Annotated[int, pyd.Field(ge=1)] = 1
    attention_strides: Annotated[int, pyd.Field(ge=1)] = 1
    attention_padding: Annotated[str, pyd.AfterValidator(vd.is_padding)] = 'same'
    attention_kernel_initializer: str = 'he_normal'
    dtype: str = 'float16'
    data_format: str = 'channels_last'
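
# Illustrative sketch (not part of the original module): validating a config.
# The values are assumptions made only for the example; pydantic raises a
# ValidationError if, for instance, layers is even (vd.is_odd) or
# attention_padding fails vd.is_padding.
#
#     config = UNetConfig(input_width=128, input_height=128, input_channels=3)
#     assert config.layers == 9 and config.dtype == 'float16'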


# PIPELINE----------------------------------------------------------------------
class UNetPipeline(ficp.PipelineBase):
    def model_config(self):
        # type: () -> type[UNetConfig]
        return UNetConfig

    def model_func(self):
        # type: () -> Callable[..., nn.Module]
        return get_unet_model
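
# Illustrative sketch (not part of the original module): how the pieces above
# could be combined by hand, outside of PipelineBase, whose orchestration is
# defined elsewhere. The config values are assumptions made only for the
# example.
#
#     config = UNetConfig(input_width=128, input_height=128, input_channels=3)
#     model = get_unet_model(
#         in_channels=config.input_channels,
#         out_channels=config.classes,
#         dtype=config.dtype,
#     )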