Coverage for /home/ubuntu/flatiron/python/flatiron/tf/models/unet.py: 85% (110 statements)
from typing_extensions import Annotated
from tensorflow import keras  # noqa F401
from keras import KerasTensor  # noqa F401

from lunchbox.enforce import Enforce
from keras import layers as tfl
from keras import models as tfmodels
import pydantic as pyd

import flatiron.core.pipeline as ficp
import flatiron.core.tools as fict
import flatiron.core.validators as vd
# ------------------------------------------------------------------------------

PAD = 18


# FUNCS-------------------------------------------------------------------------
def unet_width_and_layers_are_valid(width, layers):
    # type: (int, int) -> bool
    '''
    Determines whether given UNet width and layers are valid.

    Args:
        width (int): UNet input width.
        layers (int): Number of UNet layers.

    Returns:
        bool: True if width and layers are compatible.
    '''
    layers = int((layers - 1) / 2) - 1
    x = float(width)
    for _ in range(layers):
        x /= 2
        if x % 2 != 0:
            return False
    return True
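

# Worked example (illustrative, not from the original module): with layers=9
# the loop above halves the width int((9 - 1) / 2) - 1 = 3 times, so
# width=208 is valid (208 -> 104 -> 52 -> 26, all even), while width=200 is
# not (200 -> 100 -> 50 -> 25, odd after the final halving).
#
#   >>> unet_width_and_layers_are_valid(208, 9)
#   True
#   >>> unet_width_and_layers_are_valid(200, 9)
#   False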


def conv_2d_block(
    input_,  # type: KerasTensor
    filters=16,  # type: int
    activation='relu',  # type: str
    batch_norm=True,  # type: bool
    kernel_initializer='he_normal',  # type: str
    name='conv-2d-block',  # type: str
    dtype='float16',  # type: str
    data_format='channels_last',  # type: str
):
    # type: (...) -> KerasTensor
    r'''
    2D Convolution block without padding.

    .. math::
        :nowrap:

        \begin{align}
            architecture & \rightarrow Conv2D + ReLU + BatchNorm + Conv2D
                + ReLU + BatchNorm \\
            kernel & \rightarrow (3, 3) \\
            strides & \rightarrow (1, 1) \\
            padding & \rightarrow same \\
        \end{align}

    .. image:: images/conv_2d_block.svg
        :width: 800

    Args:
        input_ (KerasTensor): Input tensor.
        filters (int, optional): Number of filters. Default: 16.
        activation (str, optional): Activation function. Default: relu.
        batch_norm (bool, optional): Use batch normalization. Default: True.
        kernel_initializer (str, optional): Kernel initializer.
            Default: he_normal.
        name (str, optional): Layer name. Default: conv-2d-block.
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.

    Returns:
        KerasTensor: Conv2D block.
    '''
    name = fict.pad_layer_name(name, length=PAD)
    kwargs = dict(
        filters=filters,
        kernel_size=(3, 3),
        strides=(1, 1),
        activation=activation,
        kernel_initializer=kernel_initializer,
        padding='same',
        use_bias=not batch_norm,
        dtype=dtype,
        data_format=data_format,
    )

    name2 = f'{name}-1'
    conv_1 = tfl.Conv2D(**kwargs, name=f'{name}-0')(input_)
    if batch_norm:
        conv_1 = tfl.BatchNormalization(name=f'{name}-1', dtype=dtype)(conv_1)
        name2 = f'{name}-2'

    conv_2 = tfl.Conv2D(**kwargs, name=name2)(conv_1)
    if batch_norm:
        conv_2 = tfl.BatchNormalization(name=f'{name}-3', dtype=dtype)(conv_2)

    return conv_2
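

# Illustrative sketch (not part of the original module): applying the block
# above to a dummy input. The input shape, dtype, and layer name here are
# assumptions.
#
#   >>> inp = tfl.Input((64, 64, 3), name='example-input')
#   >>> out = conv_2d_block(inp, filters=8, dtype='float32', name='example')
#   >>> tuple(out.shape)  # 'same' padding preserves spatial size
#   (None, 64, 64, 8)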


def attention_gate_2d(
    query,  # type: KerasTensor
    skip_connection,  # type: KerasTensor
    activation_1='relu',  # type: str
    activation_2='sigmoid',  # type: str
    kernel_size=1,  # type: int
    strides=1,  # type: int
    padding='same',  # type: str
    kernel_initializer='he_normal',  # type: str
    name='attention-gate',  # type: str
    dtype='float16',  # type: str
    data_format='channels_last',  # type: str
):
    # type: (...) -> KerasTensor
    '''
    Attention gate for 2D inputs.
    See: https://arxiv.org/abs/1804.03999

    Args:
        query (KerasTensor): 2D query tensor.
        skip_connection (KerasTensor): 2D tensor of skip connection features.
        activation_1 (str, optional): First activation. Default: 'relu'.
        activation_2 (str, optional): Second activation. Default: 'sigmoid'.
        kernel_size (int, optional): Kernel size. Default: 1.
        strides (int, optional): Strides. Default: 1.
        padding (str, optional): Padding. Default: 'same'.
        kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'.
        name (str, optional): Layer name. Default: attention-gate.
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.

    Returns:
        KerasTensor: 2D attention gate.
    '''
    name = fict.pad_layer_name(name, length=PAD)
    filters = query.get_shape().as_list()[-1]
    kwargs = dict(
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        dtype=dtype,
        data_format=data_format,
    )
    conv_0 = tfl.Conv2D(
        filters=filters, **kwargs, name=f'{name}-0'
    )(skip_connection)
    conv_1 = tfl.Conv2D(filters=filters, **kwargs, name=f'{name}-1')(query)
    gate = tfl.add([conv_0, conv_1], name=f'{name}-2', dtype=dtype)
    gate = tfl.Activation(activation_1, name=f'{name}-3', dtype=dtype)(gate)
    gate = tfl.Conv2D(filters=1, **kwargs, name=f'{name}-4')(gate)
    gate = tfl.Activation(activation_2, name=f'{name}-5', dtype=dtype)(gate)
    gate = tfl.multiply([skip_connection, gate], name=f'{name}-6', dtype=dtype)
    output = tfl.concatenate([gate, query], name=f'{name}-7', dtype=dtype)
    return output
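

# Note (added for clarity, not in the original module): the sigmoid mask
# produced above has a single channel and is broadcast by the multiply, so
# the gated tensor keeps the skip connection's channel count. The final
# concatenation therefore yields skip_channels + query_channels channels.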


def get_unet_model(
    input_width,  # type: int
    input_height,  # type: int
    input_channels,  # type: int
    classes=1,  # type: int
    filters=32,  # type: int
    layers=9,  # type: int
    activation='leaky_relu',  # type: str
    batch_norm=True,  # type: bool
    output_activation='sigmoid',  # type: str
    kernel_initializer='he_normal',  # type: str
    attention_gates=False,  # type: bool
    attention_activation_1='relu',  # type: str
    attention_activation_2='sigmoid',  # type: str
    attention_kernel_size=1,  # type: int
    attention_strides=1,  # type: int
    attention_padding='same',  # type: str
    attention_kernel_initializer='he_normal',  # type: str
    dtype='float16',  # type: str
    data_format='channels_last',  # type: str
):
    # type: (...) -> tfmodels.Model
    '''
    UNet model for 2D semantic segmentation.

    see: https://arxiv.org/abs/1505.04597
    see: https://arxiv.org/pdf/1411.4280.pdf
    see: https://arxiv.org/abs/1804.03999

    Args:
        input_width (int): Input width.
        input_height (int): Input height.
        input_channels (int): Input channels.
        classes (int, optional): Number of output classes. Default: 1.
        filters (int, optional): Number of filters for initial conv 2d block.
            Default: 32.
        layers (int, optional): Total number of layers. Default: 9.
        activation (str, optional): Activation function to be used.
            Default: leaky_relu.
        batch_norm (bool, optional): Use batch normalization.
            Default: True.
        output_activation (str, optional): Output activation function.
            Default: sigmoid.
        kernel_initializer (str, optional): Kernel initializer.
            Default: he_normal.
        attention_gates (bool, optional): Use attention gates.
            Default: False.
        attention_activation_1 (str, optional): First activation.
            Default: 'relu'.
        attention_activation_2 (str, optional): Second activation.
            Default: 'sigmoid'.
        attention_kernel_size (int, optional): Kernel size. Default: 1.
        attention_strides (int, optional): Strides. Default: 1.
        attention_padding (str, optional): Padding. Default: 'same'.
        attention_kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'.
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.

    Raises:
        EnforceError: If input_width is not even.
        EnforceError: If input_height is not even.
        EnforceError: If layers is not an odd integer greater than 2.
        EnforceError: If input_width and layers are not compatible.

    Returns:
        tfmodels.Model: UNet model.
    '''
    # shape
    msg = 'Input width and height must be equal, even numbers. '
    msg += f'Given shape: ({input_width}, {input_height}).'
    Enforce(input_width % 2, '==', 0, message=msg)
    Enforce(input_height % 2, '==', 0, message=msg)
    Enforce(input_width, '==', input_height, message=msg)

    # layers
    msg = 'Layers must be an odd integer greater than 2. '
    msg += f'Given value: {layers}.'
    Enforce(layers, 'instance of', int, message=msg)
    Enforce(layers, '>=', 3, message=msg)
    Enforce(layers % 2, '==', 1, message=msg)

    # valid width and layers
    msg = 'Given input_width and layers are not compatible. '
    msg += f'Input_width: {input_width}. Layers: {layers}.'
    Enforce(
        unet_width_and_layers_are_valid(input_width, layers), '==', True, message=msg
    )
    # --------------------------------------------------------------------------

    n = int((layers - 1) / 2)
    encode_layers = []

    # input layer
    shape = (input_width, input_height, input_channels)
    input_ = tfl.Input(shape, name='input', dtype=dtype)

    # encode layers
    x = input_
    for i in range(n):
        # conv backend of layer
        x = conv_2d_block(
            input_=x,
            filters=filters,
            batch_norm=batch_norm,
            activation=activation,
            kernel_initializer=kernel_initializer,
            name=f'encode-block_{i:02d}',
            dtype=dtype,
            data_format=data_format,
        )
        encode_layers.append(x)

        # downsample
        name = fict.pad_layer_name(f'downsample_{i:02d}', length=PAD)
        x = tfl.MaxPooling2D(
            (2, 2), name=name, dtype=dtype, data_format=data_format,
        )(x)
        filters *= 2

    # middle layer
    name = fict.pad_layer_name('middle-block_00', length=PAD)
    x = conv_2d_block(
        input_=x,
        filters=filters,
        batch_norm=batch_norm,
        activation=activation,
        kernel_initializer=kernel_initializer,
        name=name,
        dtype=dtype,
        data_format=data_format,
    )

    # decode layers
    decode_layers = list(reversed(encode_layers))
    for i, layer in enumerate(decode_layers[:n]):
        filters = int(filters / 2)

        # upsample
        name = fict.pad_layer_name(f'upsample_{i:02d}', length=PAD)
        x = tfl.Conv2DTranspose(
            filters=filters,
            kernel_size=(2, 2),
            strides=(2, 2),
            padding='same',
            name=name,
            dtype=dtype,
            data_format=data_format,
        )(x)

        # attention gate
        if attention_gates:
            name = fict.pad_layer_name(f'attention-gate_{i:02d}', length=PAD)
            x = attention_gate_2d(
                x,
                layer,
                activation_1=attention_activation_1,
                activation_2=attention_activation_2,
                kernel_size=attention_kernel_size,
                strides=attention_strides,
                padding=attention_padding,
                kernel_initializer=attention_kernel_initializer,
                name=name,
                dtype=dtype,
                data_format=data_format,
            )
        else:
            name = fict.pad_layer_name(f'concat_{i:02d}', length=PAD)
            x = tfl.concatenate([layer, x], name=name, dtype=dtype)

        # conv backend of layer
        x = conv_2d_block(
            input_=x,
            filters=filters,
            batch_norm=batch_norm,
            activation=activation,
            kernel_initializer=kernel_initializer,
            name=f'decode-block_{i:02d}',
            dtype=dtype,
            data_format=data_format,
        )

    output = tfl.Conv2D(
        classes, (1, 1), activation=output_activation, name='output',
        dtype=dtype, data_format=data_format,
    )(x)
    model = tfmodels.Model(inputs=[input_], outputs=[output])
    return model
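

# Illustrative sketch (not part of the original module): building a small
# model. The values are assumptions chosen so that
# unet_width_and_layers_are_valid(208, 9) holds; float32 is used here instead
# of the float16 default to keep a quick test simple.
#
#   >>> model = get_unet_model(208, 208, 3, filters=8, dtype='float32')
#   >>> model.output_shape
#   (None, 208, 208, 1)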


# CONFIG------------------------------------------------------------------------
class UNetConfig(pyd.BaseModel):
    '''
    Configuration for UNet model.

    Attributes:
        input_width (int): Input width.
        input_height (int): Input height.
        input_channels (int): Input channels.
        classes (int, optional): Number of output classes. Default: 1.
        filters (int, optional): Number of filters for initial conv 2d block.
            Default: 16.
        layers (int, optional): Total number of layers. Default: 9.
        activation (str, optional): Activation function to be used.
            Default: relu.
        batch_norm (bool, optional): Use batch normalization.
            Default: True.
        output_activation (str, optional): Output activation function.
            Default: sigmoid.
        kernel_initializer (str, optional): Kernel initializer.
            Default: he_normal.
        attention_gates (bool, optional): Use attention gates.
            Default: False.
        attention_activation_1 (str, optional): First activation.
            Default: 'relu'.
        attention_activation_2 (str, optional): Second activation.
            Default: 'sigmoid'.
        attention_kernel_size (int, optional): Kernel size. Default: 1.
        attention_strides (int, optional): Strides. Default: 1.
        attention_padding (str, optional): Padding. Default: 'same'.
        attention_kernel_initializer (str, optional): Kernel initializer.
            Default: 'he_normal'.
        dtype (str, optional): Model dtype. Default: float16.
        data_format (str, optional): Model data format. Default: channels_last.
    '''
    input_width: Annotated[int, pyd.Field(ge=1)]
    input_height: Annotated[int, pyd.Field(ge=1)]
    input_channels: Annotated[int, pyd.Field(ge=1)]
    classes: Annotated[int, pyd.Field(ge=1)] = 1
    filters: Annotated[int, pyd.Field(ge=1)] = 16
    layers: Annotated[int, pyd.Field(ge=3), pyd.AfterValidator(vd.is_odd)] = 9
    activation: str = 'relu'
    batch_norm: bool = True
    output_activation: str = 'sigmoid'
    kernel_initializer: str = 'he_normal'
    attention_gates: bool = False
    attention_activation_1: str = 'relu'
    attention_activation_2: str = 'sigmoid'
    attention_kernel_size: Annotated[int, pyd.Field(ge=1)] = 1
    attention_strides: Annotated[int, pyd.Field(ge=1)] = 1
    attention_padding: Annotated[str, pyd.AfterValidator(vd.is_padding)] = 'same'
    attention_kernel_initializer: str = 'he_normal'
    dtype: str = 'float16'
    data_format: str = 'channels_last'
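

# Illustrative sketch (not part of the original module): UNetConfig validates
# its fields on construction, so an invalid layer count is rejected by the
# vd.is_odd validator before any model is built.
#
#   >>> UNetConfig(input_width=208, input_height=208, input_channels=3)
#   ... succeeds, with all other fields taking their defaults
#   >>> UNetConfig(input_width=208, input_height=208, input_channels=3, layers=8)
#   ... raises a pydantic ValidationError (layers must be odd)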


# PIPELINE----------------------------------------------------------------------
class UNetPipeline(ficp.PipelineBase):
    def model_config(self):
        # type: () -> type[UNetConfig]
        return UNetConfig

    def model_func(self):
        # type: () -> tfmodels.Model
        return get_unet_model
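

# ILLUSTRATIVE USAGE (not part of the original module): a minimal sketch that
# validates a config with UNetConfig and builds the model directly with
# get_unet_model. The field values are assumptions chosen to satisfy the
# width/layers constraint above; since the config fields and the builder's
# parameters share names, the validated dict can be splatted straight in.
if __name__ == '__main__':
    config = UNetConfig(
        input_width=208,
        input_height=208,
        input_channels=3,
        filters=16,
        layers=9,
        dtype='float32',
    )
    model = get_unet_model(**config.model_dump())
    model.summary()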