Coverage for /home/ubuntu/cv-depot/python/cv_depot/core/image.py: 99%
292 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 20:26 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 20:26 +0000
1from typing import Any, Optional, Tuple, Union # noqa F401
2from numpy.typing import NDArray # noqa F401
3from cv_depot.core.types import Filepath # noqa F401
5from copy import deepcopy
6from itertools import combinations, chain
7from pathlib import Path
8import os
9import re
11from lunchbox.enforce import Enforce
12from openexr_tools.enum import ImageCodec
13import numpy as np
14import openexr_tools.tools as exrtools
15import PIL.Image as pil
17from cv_depot.core.enum import BitDepth, ImageFormat
18from cv_depot.core.viewer import ImageViewer
19import cv_depot.core.tools as cvt
20# ------------------------------------------------------------------------------
def _has_super_darks(image):
    # type: (Image) -> bool
    '''
    Checks whether the given image contains any value below 0.0.

    Args:
        image (Image): Image instance.

    Raises:
        EnforceError: If image is not an Image instance.

    Returns:
        bool: True if the smallest value in the image data is negative.
    '''
    Enforce(image, 'instance of', Image)
    # only the global minimum matters: if it is not negative, nothing is
    lowest = image.data.min()
    is_negative = lowest < 0.0
    return bool(is_negative)
def _has_super_brights(image):
    # type: (Image) -> bool
    '''
    Checks whether the given image contains any value above 1.0.

    Args:
        image (Image): Image instance.

    Raises:
        EnforceError: If image is not an Image instance.

    Returns:
        bool: True if the largest value in the image data exceeds 1.0.
    '''
    Enforce(image, 'instance of', Image)
    # only the global maximum matters: if it is within range, everything is
    highest = image.data.max()
    exceeds_unit = highest > 1.0
    return bool(exceeds_unit)
class Image():
    '''
    Class for reading, writing, converting and displaying properties of images.

    Pixel data is held internally as a 3 dimensional numpy array indexed as
    (row, column, channel). Two dimensional input gains a trailing channel axis
    of size 1.
    '''
    @staticmethod
    def from_array(array):
        # type: (NDArray) -> Image
        '''
        Construct an Image instance from a given numpy array.

        The array is copied, so later changes to it do not affect the image.

        Args:
            array (numpy.NDArray): Numpy array with 2 or 3 dimensions.

        Raises:
            AttributeError: If array does not have 2 or 3 dimensions.

        Returns:
            Image: Image instance of given numpy array.
        '''
        # enforce bit depth compliance; the result is deliberately discarded,
        # the call exists only to validate the dtype
        BitDepth.from_dtype(array.dtype)
        return Image(array.copy(), {}, None, allow=True)

    @staticmethod
    def from_pil(image):
        # type: (pil.Image) -> Image
        '''
        Construct an Image instance from a given PIL Image.

        Args:
            image (pil.Image): PIL Image.

        Returns:
            Image: Image instance of a given PIL Image.
        '''
        return Image.from_array(np.array(image))

    @staticmethod
    def read(filepath):
        # type: (Filepath) -> Image
        '''
        Constructs an Image instance given a full path to an image file.

        EXR files are read with openexr_tools (data and metadata). Every other
        format is read with PIL and carries empty metadata.

        Args:
            filepath (str or Path): Image filepath.

        Raises:
            FileNotFoundError: If file could not be found on disk.
            TypeError: If filepath is not a str or Path.

        Returns:
            Image: Image instance of given file.
        '''
        if isinstance(filepath, Path):
            filepath = filepath.absolute().as_posix()

        if not isinstance(filepath, str):
            msg = f'Object of type {filepath.__class__.__name__} '
            msg += 'is not a str or Path.'
            raise TypeError(msg)

        if not os.path.exists(filepath):
            msg = f'{filepath} does not exist.'
            raise FileNotFoundError(msg)

        metadata = {}  # type: dict[str, Any]
        _, ext = os.path.splitext(filepath)
        format_ = ImageFormat.from_extension(ext)

        if format_ is ImageFormat.EXR:
            data, metadata = exrtools.read_exr(filepath)
        else:
            # convert inside the with block so the pixels are loaded before
            # the underlying file handle is released
            with pil.open(filepath) as handle:
                data = np.asarray(handle)

        return Image(data, metadata, format_, allow=True)

    def __init__(self, data, metadata=None, format_=None, allow=False):
        # type: (NDArray, Optional[dict[str, Any]], Optional[ImageFormat], bool) -> None
        '''
        This constructor should not be called directly except internally and in
        testing.

        Args:
            data (numpy.NDArray): Image.
            metadata (dict, optional): Image metadata. A new empty dict is used
                when None. Default: None.
            format_ (ImageFormat, optional): Format of image. Default: None.
            allow (bool, optional): Whether to allow construction using init.
                Default: False.

        Raises:
            NotImplementedError: If allow is False.
            AttributeError: If image data dimensions are not 2 or 3.

        Returns:
            Image: Image instance.
        '''
        if not allow:
            msg = "Please call one of Image's static constructors to create an "
            msg += 'instance. Options include: read, from_array.'
            raise NotImplementedError(msg)

        # a fresh dict per instance; a shared default dict would leak channel
        # metadata from one instance to the next
        if metadata is None:
            metadata = {}

        # ensure data has 3 dimensions
        shape = data.shape
        dims = len(shape)
        if dims > 3 or dims < 2:
            msg = f'Illegal number of dimensions for image data. {dims} not in '
            msg += '[2, 3].'
            raise AttributeError(msg)

        if dims == 2:
            data = data.reshape((*shape, 1))

        self._data = data
        self.metadata = metadata
        self.format = format_
    # --------------------------------------------------------------------------

    def _repr(self):
        # type: () -> str
        '''
        Creates a text summary of the image.

        Returns:
            str: Width, height, channel count, bit depth and format, one per
                line.
        '''
        fmat = str(None)
        if self.format is not None:
            fmat = self.format.name

        # NOTE(review): whitespace inside this literal is reproduced exactly as
        # seen; confirm the intended column alignment of the labels.
        # [1:] drops the newline that follows the opening quotes.
        return f'''
 WIDTH: {self.width}
 HEIGHT: {self.height}
NUM_CHANNELS: {self.num_channels}
 BIT_DEPTH: {self.bit_depth.name}
 FORMAT: {fmat}'''[1:]

    def _repr_html_(self):
        # type: () -> None
        '''
        Creates a HTML representation of image data.

        Displays the image through ImageViewer and returns nothing.
        '''
        ImageViewer(self).show()

    def _repr_png(self):
        # type: () -> Optional[bytes]
        '''
        Creates a PNG representation of image data.

        Images with values outside of [0, 1] are normalized first, because
        conversion to 8-bit rejects such values.

        Returns:
            bytes: PNG.
        '''
        this = self
        if _has_super_brights(self) or _has_super_darks(self):
            this = self.to_unit_space()
        output = this.to_bit_depth(BitDepth.UINT8).to_pil()._repr_png_()
        return output

    def _string_to_channels(self, string):
        # type: (str) -> list
        '''
        Converts string to list of channels.

        Resolution order:

            1. A string made only of the characters r, g, b, a is split into
               single character channels.
            2. A layer name is expanded to its [layer-name].[channel] channels.
            3. Anything else is returned as a single channel name.

        Args:
            string (str): String representation of channels.

        Returns:
            list: List of channels.
        '''
        # special rgba short circuit: any non-empty subset of 'rgba'
        chars = set(string)
        if chars and chars <= set('rgba'):
            return list(string)

        # if channels is actually a layer name
        if string in self.channel_layers:
            # escape the layer name so regex metacharacters in it are literal,
            # and anchor at the start so 'spec' does not match 'my_spec.R'
            pattern = re.compile(re.escape(string) + r'\..+')
            found = [x for x in self.channels if pattern.match(str(x))]
            if found:
                # found channels that matched [layer-name].[channel] pattern
                return found

            # a multi-character channel without a layer prefix (e.g. 'depth')
            # forms a layer of the same name; keep it whole
            if string in self.channels:
                return [string]
            return list(string)
        return [string]

    def __getitem__(self, indices):
        # type: (Union[int, tuple, list, slice, str]) -> Image
        '''
        Gets slice of image data. Indices are given in the order:
        width, height, channel.

        Channels may be given as a slice, an integer index, a string (see
        _string_to_channels) or a list/tuple of channel names and integer
        indices. When a list, tuple or integer is given, the resulting image's
        channel names are the given items themselves.

        Args:
            indices (int, tuple, list, slice, str): Slice of image data.

        Raises:
            IndexError: If number of indices provided is greater than 3.
            IndexError: If channel given is illegal.
            IndexError: If three lists are given as indices.

        Returns:
            Image: Image slice.
        '''
        # anything that is not a tuple (including a list) is a single index
        if not isinstance(indices, tuple):
            indices = [indices]

        size = len(indices)
        if size > 3:
            msg = f'Number of dimensions provided: {size}, is greater than 3.'
            raise IndexError(msg)

        # convert indices to triplet of slices
        columns = slice(None, None)  # type: Any
        rows = slice(None, None)  # type: Any
        channels = slice(None, None)  # type: Any
        if size == 3:
            columns, rows, channels = indices
        elif size == 2:
            columns, rows = indices
        else:
            columns = indices[0]

        # a bare integer channel behaves like a one item list, so that the
        # channel metadata of the result matches its single channel of data
        is_int = isinstance(channels, (int, np.integer))
        if is_int and not isinstance(channels, bool):
            channels = [channels]

        # convert channels to list of indices
        channel_meta = self.metadata.get('channels', [])
        if isinstance(channels, (str, tuple, list)):
            if isinstance(channels, str):
                channels = self._string_to_channels(channels)
            channel_meta = channels

            names = self.channels
            chans = []
            for channel in channels:
                if isinstance(channel, str):
                    if channel not in names:
                        msg = f'{channel} is not a legal channel name.'
                        raise IndexError(msg)
                    channel = names.index(channel)
                chans.append(channel)

            # a single channel is indexed with a scalar; __init__ restores
            # the channel axis of the resulting 2 dimensional data
            if len(chans) == 1:
                chans = chans[0]
            channels = chans

        # coerce to list for simpler logic
        if isinstance(columns, tuple):
            columns = list(columns)
        if isinstance(rows, tuple):
            rows = list(rows)

        if all(isinstance(x, list) for x in (columns, rows, channels)):
            msg = 'Three lists are not acceptable as indices.'
            raise IndexError(msg)

        if isinstance(channels, slice):
            channel_meta = self.channels[channels]

        # data is stored row first, hence the swapped order
        data = self._data[rows, columns, channels]
        metadata = deepcopy(self.metadata)
        metadata['channels'] = channel_meta
        return Image(data, metadata=metadata, format_=self.format, allow=True)
    # --------------------------------------------------------------------------

    def set_channels(self, channels):
        # type: (list[Union[str, int]]) -> Image
        '''
        Sets channel names.

        The given list is neither modified nor retained.

        Args:
            channels (list[str or int]): List of channel names.

        Raises:
            ValueError: If number of channels given does not equal data
                shape.
            ValueError: If duplicate channel names found.

        Returns:
            Image: New Image instance with a copy of the data and the given
                channel names.
        '''
        if len(channels) != self.num_channels:
            msg = 'Number of channels given does not equal last dimension size.'
            msg += f' {len(channels)} != {self.num_channels}.'
            raise ValueError(msg)

        uniq = set(channels)
        if len(uniq) < len(channels):
            # strip one occurrence of every name from a copy; what remains
            # are the duplicates
            duplicates = list(channels)
            for c in uniq:
                duplicates.remove(c)
            msg = f'Duplicate channel names found: {duplicates}.'
            raise ValueError(msg)

        metadata = deepcopy(self.metadata)
        metadata['channels'] = list(channels)
        return Image(
            self._data.copy(),
            metadata=metadata,
            format_=self.format,
            allow=True,
        )

    def write(self, filepath, codec=ImageCodec.PIZ):
        # type: (Filepath, ImageCodec) -> None
        '''
        Write image to file.

        The target format is derived from the file extension. Writing does not
        modify this instance.

        Args:
            filepath (str or Path): Full path to image file.
            codec (ImageCodec, optional): EXR compression scheme to be used.
                Default: ImageCodec.PIZ.

        Raises:
            TypeError: If format does not support instance bit depth.
            AttributeError: If format does not support the number of channels in
                instance.
        '''
        if isinstance(filepath, Path):
            filepath = filepath.absolute().as_posix()

        _, suffix = os.path.splitext(filepath)
        format_ = ImageFormat.from_extension(suffix)

        # ensure format is compatible with image data
        if self.bit_depth not in format_.bit_depths:
            msg = f'{format_.name} cannot be written with {self.bit_depth.name}'
            msg += ' data.'
            raise TypeError(msg)

        if self.num_channels > format_.max_channels:
            msg = f'{format_.name} cannot be written with {self.num_channels} '
            msg += f'channels. Max channels supported: {format_.max_channels}.'
            raise AttributeError(msg)

        # write data
        if format_ is ImageFormat.EXR:
            # shallow copy, so adding the channels key leaves self.metadata
            # untouched
            metadata = dict(self.metadata)
            metadata['channels'] = self.channels
            exrtools.write_exr(filepath, self._data, metadata, codec)

        else:
            # self.data drops the channel axis of single channel images, as
            # to_pil does
            pil.fromarray(self.data).save(filepath, format=format_.name)

    def to_bit_depth(self, bit_depth):
        # type: (BitDepth) -> Image
        '''
        Convert image to given bit depth.
        Warning: Numpy's conversions for INT8 are bizarre.

        Values are rescaled only between UINT8 and float (by 255). All other
        conversions are plain dtype casts.

        Args:
            bit_depth (BitDepth): Target bit depth.

        Raises:
            ValueError: If converting from float to 8-bit and values exceed 1.
            ValueError: If converting from float to 8-bit and values less than 0.

        Returns:
            Image: Image at given bit depth. This same instance is returned if
                it already has the given bit depth.
        '''
        image = self._data
        src = self.bit_depth
        tgt = bit_depth

        if src is tgt:
            return self

        elif src is BitDepth.UINT8 and tgt.type_ is float:
            image = image.astype(tgt.dtype) / 255

        elif src.type_ is float and tgt.bits == 8:
            if _has_super_darks(self):
                msg = f'Image has values below 0. Min value: {image.min()}'
                raise ValueError(msg)

            if _has_super_brights(self):
                msg = f'Image has values above 1. Max value: {image.max()}'
                raise ValueError(msg)

            image = (image * 255).astype(tgt.dtype)

        else:
            image = image.astype(tgt.dtype)

        metadata = deepcopy(self.metadata)
        return Image(image, metadata=metadata, format_=self.format, allow=True)

    def to_unit_space(self):
        # type: () -> Image
        '''
        Normalizes image to [0, 1] range.

        An image in which every value is identical has no range to stretch and
        is normalized to all zeros.

        Returns:
            Image: Normalized image.
        '''
        data = self.to_bit_depth(BitDepth.FLOAT32)._data.copy()
        max_, min_ = data.max(), data.min()
        span = max_ - min_
        if span == 0:
            # avoid dividing by zero, which would fill the image with NaN
            data = np.zeros_like(data)
        else:
            data = (data - min_) / span
        data = Image.from_array(data).to_bit_depth(self.bit_depth)._data
        metadata = deepcopy(self.metadata)
        return Image(data, metadata=metadata, format_=self.format, allow=True)

    def to_array(self):
        # type: () -> NDArray
        '''
        Returns numpy array.

        Returns:
            numpy.NDArray: Image as numpy array.
        '''
        return self.data

    def to_pil(self):
        # type: () -> pil.Image
        '''
        Returns pil.Image.

        Raises:
            ValueError: If image does not have 1, 3 or 4 channels.

        Returns:
            pil: Image as pil.Image.
        '''
        if self.num_channels == 1:
            mode = 'L'
        elif self.num_channels == 3:
            mode = 'RGB'
        elif self.num_channels == 4:
            mode = 'RGBA'
        else:
            raise ValueError('PIL only accepts image with 1, 3 or 4 channels.')
        return pil.fromarray(self.data, mode=mode)

    def compare(self, image, content=False, diff_only=False):
        # type: (Image, bool, bool) -> dict[str, Any]
        '''
        Compare this image with a given image.

        Each key of the result maps to a (this image, given image) pair of info
        values. With content enabled, a mean_content_difference key holds the
        mean absolute difference of both images converted to FLOAT16.

        Args:
            image (Image): Image to compare.
            content (bool, optional): If True, compare data. Default: False.
            diff_only (bool, optional): If True, only return the keys with
                differing values. Default: False.

        Raises:
            EnforceError: If image is not an Image instance.
            ValueError: If content is True and images cannot be compared.

        Returns:
            dict: A dictionary of comparisons.
        '''
        msg = 'Image must be an instance of Image.'
        Enforce(image, 'instance of', Image, message=msg)
        # ----------------------------------------------------------------------

        a = self.info
        b = image.info
        output = {}  # type: dict[str, Any]
        for k, v in a.items():
            output[k] = (v, b.get(k, None))
        for k, v in b.items():
            if k not in output:
                output[k] = (a.get(k, None), v)

        if diff_only:
            # iterate over a snapshot, since keys are deleted along the way
            for k, v in list(output.items()):
                if v[0] == v[1]:
                    del output[k]

        if content:
            x = self.to_bit_depth(BitDepth.FLOAT16).data
            y = image.to_bit_depth(BitDepth.FLOAT16).data

            try:
                diff = float(abs(x - y).mean())
            except ValueError as e:
                raise ValueError(f'Cannot compare images: {e}') from e

            output['mean_content_difference'] = diff
            if diff_only and diff == 0:
                del output['mean_content_difference']

        return output

    def __eq__(self, image):
        # type: (object) -> bool
        '''
        Compare this image with a given image.

        Images are equal when their info is identical and their mean content
        difference is 0. Images whose data cannot be compared are not equal.

        Returns:
            bool: True if images are equal. NotImplemented if given object is
                not an Image.
        '''
        if not isinstance(image, Image):
            # lets python fall back to its default handling, so comparing
            # against None or other types is False instead of an error
            return NotImplemented
        try:
            return self.compare(image, content=True, diff_only=True) == {}
        except ValueError:
            return False
    # --------------------------------------------------------------------------

    @property
    def data(self):
        # type: () -> NDArray
        '''
        numpy.NDArray: Image data. Single channel images are returned without
        their channel axis.
        '''
        if self.num_channels == 1:
            return np.squeeze(self._data, axis=2)
        return self._data

    @property
    def info(self):
        # type: () -> dict[str, Any]
        '''
        dict: A dictionary of all information about the Image instance. Format
        keys are None when the image has no format.
        '''
        output = dict(
            width=self.width,
            height=self.height,
            channels=self.channels,
            num_channels=self.num_channels,
            bit_depth=self.bit_depth.name,
            dtype=self.bit_depth.dtype,
            bits=self.bit_depth.bits,
            signed=self.bit_depth.signed,
            type=self.bit_depth.type_,
            format_extension=None,
            format_bit_depths=None,
            format_channels=None,
            format_max_channels=None,
            format_custom_metadata=None,
        )
        if self.format is not None:
            fmat = dict(
                format_extension=self.extension,
                format_bit_depths=self.format.bit_depths,
                format_channels=self.format.channels,
                format_max_channels=self.format.max_channels,
                format_custom_metadata=self.format.custom_metadata,
            )
            output.update(fmat)
        return output

    @property
    def shape(self):
        # type: () -> Tuple[int, int, int]
        '''
        tuple[int]: (width, height, channels) of image.
        '''
        return (self.width, self.height, self.num_channels)

    @property
    def width(self):
        # type: () -> int
        '''
        int: Width of image.
        '''
        return self._data.shape[1]

    @property
    def height(self):
        # type: () -> int
        '''
        int: Height of image.
        '''
        return self._data.shape[0]

    @property
    def width_and_height(self):
        # type: () -> Tuple[int, int]
        '''
        tuple[int]: (width, height) of image.
        '''
        return (self.width, self.height)

    @property
    def channels(self):
        # type: () -> list[Union[str, int]]
        '''
        list[str or int]: List of channel names. Taken from metadata when
        present, otherwise derived from the data.
        '''
        if 'channels' in self.metadata:
            return self.metadata['channels']
        return cvt.get_channels_from_array(self._data)

    @property
    def num_channels(self):
        # type: () -> int
        '''
        int: Number of channels in image.
        '''
        return len(self.channels)

    @property
    def max_channels(self):
        # type: () -> Optional[int]
        '''
        int: Maximum number of channels supported by image format. None if
        image has no format.
        '''
        if self.format is None:
            return None
        return self.format.max_channels

    @property
    def channel_layers(self):
        # type: () -> list[str]
        '''
        list[str]: List of channel layers.

        Channels without a '.' are grouped, in order, into layers of up to 4
        channels, each named by joining its channel names. Channels with a '.'
        belong to the layer named by the text before the first '.'.
        '''
        channels = [str(x) for x in self.channels]
        with_layer = [x for x in channels if '.' in x]  # type: list[str]
        wo_layer_name = [x for x in channels if '.' not in x]

        # break out channels without layer names into groups of 4
        layers = [
            ''.join(wo_layer_name[idx:idx + 4])
            for idx in range(0, len(wo_layer_name), 4)
        ]

        # append all unique layers of channels with layer names
        for chan in with_layer:
            layer = chan.split('.')[0]
            if layer not in layers:
                layers.append(layer)
        return layers

    @property
    def bit_depth(self):
        # type: () -> BitDepth
        '''
        BitDepth: Bit depth of image.
        '''
        return BitDepth.from_dtype(self._data.dtype)

    @property
    def extension(self):
        # type: () -> Optional[str]
        '''
        str: Image format extension. None if image has no format.
        '''
        if self.format is None:
            return None
        return self.format.extension