Coverage for /home/ubuntu/flatiron/python/flatiron/tf/models/unet.py: 85% (110 statements), coverage.py v7.8.0, created at 2025-04-08 21:55 +0000


from typing_extensions import Annotated
from tensorflow import keras  # noqa F401
from keras import KerasTensor  # noqa F401

from lunchbox.enforce import Enforce
from keras import layers as tfl
from keras import models as tfmodels
import pydantic as pyd

import flatiron.core.pipeline as ficp
import flatiron.core.tools as fict
import flatiron.core.validators as vd
# ------------------------------------------------------------------------------


PAD = 18


# FUNCS-------------------------------------------------------------------------
def unet_width_and_layers_are_valid(width, layers):
    # type: (int, int) -> bool
    '''
    Determines whether given UNet width and layers are valid.

    Args:
        width (int): UNet input width.
        layers (int): Number of UNet layers.

    Returns:
        bool: True if width and layers are compatible.
    '''
    layers = int((layers - 1) / 2) - 1
    x = float(width)
    for _ in range(layers):
        x /= 2
        if x % 2 != 0:
            return False
    return True
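
# A minimal usage sketch (the widths below are illustrative assumptions, not
# values from this module). For layers=9 the width is halved
# int((9 - 1) / 2) - 1 = 3 times, and every intermediate value must stay even:
#
#     unet_width_and_layers_are_valid(208, 9)  # 104 -> 52 -> 26, all even: True
#     unet_width_and_layers_are_valid(100, 9)  # 50 -> 25, 25 is odd: False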


def conv_2d_block(
    input_,  # type: KerasTensor
    filters=16,  # type: int
    activation='relu',  # type: str
    batch_norm=True,  # type: bool
    kernel_initializer='he_normal',  # type: str
    name='conv-2d-block',  # type: str
    dtype='float16',  # type: str
    data_format='channels_last',  # type: str
):
    # type: (...) -> KerasTensor
    r'''
    2D Convolution block with 'same' padding.

    .. math::
        :nowrap:

            \begin{align}
                architecture & \rightarrow Conv2D + ReLU + BatchNorm + Conv2D
                + ReLU + BatchNorm \\
                kernel & \rightarrow (3, 3) \\
                strides & \rightarrow (1, 1) \\
                padding & \rightarrow same \\
            \end{align}

    .. image:: images/conv_2d_block.svg
      :width: 800

    Args:
        input_ (KerasTensor): Input tensor.
        filters (int, optional): Number of filters. Default: 16.
        activation (str, optional): Activation function. Default: relu.
        batch_norm (bool, optional): Use batch normalization. Default: True.
        kernel_initializer (str, optional): Kernel initializer.
            Default: he_normal.
        name (str, optional): Layer name. Default: conv-2d-block
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.

    Returns:
        KerasTensor: Conv2D Block
    '''
    name = fict.pad_layer_name(name, length=PAD)
    kwargs = dict(
        filters=filters,
        kernel_size=(3, 3),
        strides=(1, 1),
        activation=activation,
        kernel_initializer=kernel_initializer,
        padding='same',
        use_bias=not batch_norm,
        dtype=dtype,
        data_format=data_format,
    )

    name2 = f'{name}-1'
    conv_1 = tfl.Conv2D(**kwargs, name=f'{name}-0')(input_)
    if batch_norm:
        conv_1 = tfl.BatchNormalization(name=f'{name}-1', dtype=dtype)(conv_1)
        name2 = f'{name}-2'

    conv_2 = tfl.Conv2D(**kwargs, name=name2)(conv_1)
    if batch_norm:
        conv_2 = tfl.BatchNormalization(name=f'{name}-3', dtype=dtype)(conv_2)

    return conv_2
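
# A minimal usage sketch for conv_2d_block (the input shape below is an
# illustrative assumption, not a value from this module). With 'same' padding
# and (1, 1) strides the block preserves spatial size and only changes the
# channel count:
#
#     x = tfl.Input((128, 128, 3), name='example-input', dtype='float16')
#     y = conv_2d_block(x, filters=16)  # shape: (None, 128, 128, 16)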


def attention_gate_2d(
    query,  # type: KerasTensor
    skip_connection,  # type: KerasTensor
    activation_1='relu',  # type: str
    activation_2='sigmoid',  # type: str
    kernel_size=1,  # type: int
    strides=1,  # type: int
    padding='same',  # type: str
    kernel_initializer='he_normal',  # type: str
    name='attention-gate',  # type: str
    dtype='float16',  # type: str
    data_format='channels_last',  # type: str
):
    # type: (...) -> KerasTensor
    '''
    Attention gate for 2D inputs.
    See: https://arxiv.org/abs/1804.03999

    Args:
        query (KerasTensor): 2D Tensor of query.
        skip_connection (KerasTensor): 2D Tensor of features.
        activation_1 (str, optional): First activation. Default: 'relu'
        activation_2 (str, optional): Second activation. Default: 'sigmoid'
        kernel_size (int, optional): Kernel size. Default: 1
        strides (int, optional): Strides. Default: 1
        padding (str, optional): Padding. Default: 'same'
        kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'
        name (str, optional): Layer name. Default: attention-gate
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.

    Returns:
        KerasTensor: 2D Attention Gate.
    '''
    name = fict.pad_layer_name(name, length=PAD)
    filters = query.get_shape().as_list()[-1]
    kwargs = dict(
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        dtype=dtype,
        data_format=data_format,
    )
    conv_0 = tfl.Conv2D(
        filters=filters, **kwargs, name=f'{name}-0'
    )(skip_connection)
    conv_1 = tfl.Conv2D(filters=filters, **kwargs, name=f'{name}-1')(query)
    gate = tfl.add([conv_0, conv_1], name=f'{name}-2', dtype=dtype)
    gate = tfl.Activation(activation_1, name=f'{name}-3', dtype=dtype)(gate)
    gate = tfl.Conv2D(filters=1, **kwargs, name=f'{name}-4')(gate)
    gate = tfl.Activation(activation_2, name=f'{name}-5', dtype=dtype)(gate)
    gate = tfl.multiply([skip_connection, gate], name=f'{name}-6', dtype=dtype)
    output = tfl.concatenate([gate, query], name=f'{name}-7', dtype=dtype)
    return output
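
# A minimal sketch of how the gate is used in the decode path (tensor shapes
# are illustrative assumptions). The skip connection is re-weighted by the
# sigmoid gate and then concatenated with the query, so the output channel
# count is the sum of both inputs' channels:
#
#     query = ...  # upsampled decoder tensor, e.g. shape (None, 52, 52, 64)
#     skip = ...   # matching encoder tensor, e.g. shape (None, 52, 52, 64)
#     out = attention_gate_2d(query, skip)  # -> (None, 52, 52, 128)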


def get_unet_model(
    input_width,  # type: int
    input_height,  # type: int
    input_channels,  # type: int
    classes=1,  # type: int
    filters=32,  # type: int
    layers=9,  # type: int
    activation='leaky_relu',  # type: str
    batch_norm=True,  # type: bool
    output_activation='sigmoid',  # type: str
    kernel_initializer='he_normal',  # type: str
    attention_gates=False,  # type: bool
    attention_activation_1='relu',  # type: str
    attention_activation_2='sigmoid',  # type: str
    attention_kernel_size=1,  # type: int
    attention_strides=1,  # type: int
    attention_padding='same',  # type: str
    attention_kernel_initializer='he_normal',  # type: str
    dtype='float16',  # type: str
    data_format='channels_last',  # type: str
):
    # type: (...) -> tfmodels.Model
    '''
    UNet model for 2D semantic segmentation.

    see: https://arxiv.org/abs/1505.04597
    see: https://arxiv.org/pdf/1411.4280.pdf
    see: https://arxiv.org/abs/1804.03999

    Args:
        input_width (int): Input width.
        input_height (int): Input height.
        input_channels (int): Input channels.
        classes (int, optional): Number of output classes. Default: 1.
        filters (int, optional): Number of filters for initial conv 2d block.
            Default: 32.
        layers (int, optional): Total number of layers. Default: 9.
        activation (str, optional): Activation function to be used.
            Default: leaky_relu.
        batch_norm (bool, optional): Use batch normalization.
            Default: True.
        output_activation (str, optional): Output activation function.
            Default: sigmoid.
        kernel_initializer (str, optional): Kernel initializer.
            Default: he_normal.
        attention_gates (bool, optional): Use attention gates.
            Default: False.
        attention_activation_1 (str, optional): First activation.
            Default: 'relu'
        attention_activation_2 (str, optional): Second activation.
            Default: 'sigmoid'
        attention_kernel_size (int, optional): Kernel size. Default: 1
        attention_strides (int, optional): Strides. Default: 1
        attention_padding (str, optional): Padding. Default: 'same'
        attention_kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.

    Raises:
        EnforceError: If input_width is not even.
        EnforceError: If input_height is not even.
        EnforceError: If input_width and input_height are not equal.
        EnforceError: If layers is not an odd integer greater than 2.
        EnforceError: If input_width and layers are not compatible.

    Returns:
        tfmodels.Model: UNet model.
    '''
    # shape
    msg = 'Input width and height must be equal, even numbers. '
    msg += f'Given shape: ({input_width}, {input_height}).'
    Enforce(input_width % 2, '==', 0, message=msg)
    Enforce(input_height % 2, '==', 0, message=msg)
    Enforce(input_width, '==', input_height, message=msg)

    # layers
    msg = 'Layers must be an odd integer greater than 2. '
    msg += f'Given value: {layers}.'
    Enforce(layers, 'instance of', int, message=msg)
    Enforce(layers, '>=', 3, message=msg)
    Enforce(layers % 2, '==', 1, message=msg)

    # valid width and layers
    msg = 'Given input_width and layers are not compatible. '
    msg += f'Input_width: {input_width}. Layers: {layers}.'
    Enforce(
        unet_width_and_layers_are_valid(input_width, layers), '==', True, message=msg
    )
    # --------------------------------------------------------------------------

    n = int((layers - 1) / 2)
    encode_layers = []

    # input layer
    shape = (input_width, input_height, input_channels)
    input_ = tfl.Input(shape, name='input', dtype=dtype)

    # encode layers
    x = input_
    for i in range(n):
        # conv backend of layer
        x = conv_2d_block(
            input_=x,
            filters=filters,
            batch_norm=batch_norm,
            activation=activation,
            kernel_initializer=kernel_initializer,
            name=f'encode-block_{i:02d}',
            dtype=dtype,
            data_format=data_format,
        )
        encode_layers.append(x)

        # downsample
        name = fict.pad_layer_name(f'downsample_{i:02d}', length=PAD)
        x = tfl.MaxPooling2D(
            (2, 2), name=name, dtype=dtype, data_format=data_format,
        )(x)
        filters *= 2

    # middle layer
    name = fict.pad_layer_name('middle-block_00', length=PAD)
    x = conv_2d_block(
        input_=x,
        filters=filters,
        batch_norm=batch_norm,
        activation=activation,
        kernel_initializer=kernel_initializer,
        name=name,
        dtype=dtype,
        data_format=data_format,
    )

    # decode layers
    decode_layers = list(reversed(encode_layers))
    for i, layer in enumerate(decode_layers[:n]):
        filters = int(filters / 2)

        # upsample
        name = fict.pad_layer_name(f'upsample_{i:02d}', length=PAD)
        x = tfl.Conv2DTranspose(
            filters=filters,
            kernel_size=(2, 2),
            strides=(2, 2),
            padding='same',
            name=name,
            dtype=dtype,
            data_format=data_format,
        )(x)

        # attention gate
        if attention_gates:
            name = fict.pad_layer_name(f'attention-gate_{i:02d}', length=PAD)
            x = attention_gate_2d(
                x,
                layer,
                activation_1=attention_activation_1,
                activation_2=attention_activation_2,
                kernel_size=attention_kernel_size,
                strides=attention_strides,
                padding=attention_padding,
                kernel_initializer=attention_kernel_initializer,
                name=name,
                dtype=dtype,
                data_format=data_format,
            )
        else:
            name = fict.pad_layer_name(f'concat_{i:02d}', length=PAD)
            x = tfl.concatenate([layer, x], name=name, dtype=dtype)

        # conv backend of layer
        x = conv_2d_block(
            input_=x,
            filters=filters,
            batch_norm=batch_norm,
            activation=activation,
            kernel_initializer=kernel_initializer,
            name=f'decode-block_{i:02d}',
            dtype=dtype,
            data_format=data_format,
        )

    output = tfl.Conv2D(
        classes, (1, 1), activation=output_activation, name='output',
        dtype=dtype, data_format=data_format,
    )(x)
    model = tfmodels.Model(inputs=[input_], outputs=[output])
    return model
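
# A minimal usage sketch (the sizes below are illustrative assumptions, not
# values from this module). With the default 9 layers, the width and height
# must survive the encode-path halvings, e.g. 208 -> 104 -> 52 -> 26 -> 13:
#
#     model = get_unet_model(
#         input_width=208, input_height=208, input_channels=3, filters=16
#     )
#     model.summary()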


# CONFIG------------------------------------------------------------------------
class UNetConfig(pyd.BaseModel):
    '''
    Configuration for UNet model.

    Attributes:
        input_width (int): Input width.
        input_height (int): Input height.
        input_channels (int): Input channels.
        classes (int, optional): Number of output classes. Default: 1.
        filters (int, optional): Number of filters for initial conv 2d block.
            Default: 16.
        layers (int, optional): Total number of layers. Default: 9.
        activation (str, optional): Activation function to be used.
            Default: relu.
        batch_norm (bool, optional): Use batch normalization.
            Default: True.
        output_activation (str, optional): Output activation function.
            Default: sigmoid.
        kernel_initializer (str, optional): Kernel initializer.
            Default: he_normal.
        attention_gates (bool, optional): Use attention gates.
            Default: False.
        attention_activation_1 (str, optional): First activation.
            Default: 'relu'
        attention_activation_2 (str, optional): Second activation.
            Default: 'sigmoid'
        attention_kernel_size (int, optional): Kernel size. Default: 1
        attention_strides (int, optional): Strides. Default: 1
        attention_padding (str, optional): Padding. Default: 'same'
        attention_kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.
    '''
    input_width: Annotated[int, pyd.Field(ge=1)]
    input_height: Annotated[int, pyd.Field(ge=1)]
    input_channels: Annotated[int, pyd.Field(ge=1)]
    classes: Annotated[int, pyd.Field(ge=1)] = 1
    filters: Annotated[int, pyd.Field(ge=1)] = 16
    layers: Annotated[int, pyd.Field(ge=3), pyd.AfterValidator(vd.is_odd)] = 9
    activation: str = 'relu'
    batch_norm: bool = True
    output_activation: str = 'sigmoid'
    kernel_initializer: str = 'he_normal'
    attention_gates: bool = False
    attention_activation_1: str = 'relu'
    attention_activation_2: str = 'sigmoid'
    attention_kernel_size: Annotated[int, pyd.Field(ge=1)] = 1
    attention_strides: Annotated[int, pyd.Field(ge=1)] = 1
    attention_padding: Annotated[str, pyd.AfterValidator(vd.is_padding)] = 'same'
    attention_kernel_initializer: str = 'he_normal'
    dtype: str = 'float16'
    data_format: str = 'channels_last'
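
# A minimal usage sketch (the field values are illustrative assumptions). The
# config fields mirror get_unet_model's keyword arguments, so a validated
# config can be unpacked directly into the model factory:
#
#     config = UNetConfig(input_width=208, input_height=208, input_channels=3)
#     model = get_unet_model(**config.model_dump())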


# PIPELINE----------------------------------------------------------------------
class UNetPipeline(ficp.PipelineBase):
    def model_config(self):
        # type: () -> type[UNetConfig]
        return UNetConfig

    def model_func(self):
        # type: () -> tfmodels.Model
        return get_unet_model
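
# A brief note on the pipeline hookup (the surrounding workflow lives in
# flatiron.core.pipeline and is assumed rather than shown here): model_config()
# supplies the UNetConfig schema that presumably validates the model section of
# a pipeline config, and model_func() supplies get_unet_model, which is
# presumably called with those validated fields to build the Keras model.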