Coverage for /home/ubuntu/cv-depot/python/cv_depot/core/image.py: 99%

292 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 20:26 +0000

1from typing import Any, Optional, Tuple, Union # noqa F401 

2from numpy.typing import NDArray # noqa F401 

3from cv_depot.core.types import Filepath # noqa F401 

4 

5from copy import deepcopy 

6from itertools import combinations, chain 

7from pathlib import Path 

8import os 

9import re 

10 

11from lunchbox.enforce import Enforce 

12from openexr_tools.enum import ImageCodec 

13import numpy as np 

14import openexr_tools.tools as exrtools 

15import PIL.Image as pil 

16 

17from cv_depot.core.enum import BitDepth, ImageFormat 

18from cv_depot.core.viewer import ImageViewer 

19import cv_depot.core.tools as cvt 

20# ------------------------------------------------------------------------------ 

21 

22 

def _has_super_darks(image):
    # type: (Image) -> bool
    '''
    Checks whether an image contains any pixel value below 0.0.

    Note: numpy's min() propagates NaN, so data containing NaN yields False.

    Args:
        image (Image): Image instance to inspect.

    Raises:
        EnforceError: If image is not an Image instance.

    Returns:
        bool: True if the smallest value in the image is below 0.0.
    '''
    Enforce(image, 'instance of', Image)
    lowest = image.data.min()
    return bool(lowest < 0.0)

39 

40 

def _has_super_brights(image):
    # type: (Image) -> bool
    '''
    Checks whether an image contains any pixel value above 1.0.

    Note: numpy's max() propagates NaN, so data containing NaN yields False.

    Args:
        image (Image): Image instance to inspect.

    Raises:
        EnforceError: If image is not an Image instance.

    Returns:
        bool: True if the largest value in the image exceeds 1.0.
    '''
    Enforce(image, 'instance of', Image)
    highest = image.data.max()
    return bool(highest > 1.0)

57 

58 

class Image():
    '''
    Class for reading, writing, converting and displaying properties of images.

    Pixel data is stored internally as a 3-dimensional (height, width,
    channels) numpy array; 2-dimensional input is given a single channel axis.
    '''
    @staticmethod
    def from_array(array):
        # type: (NDArray) -> Image
        '''
        Construct an Image instance from a given numpy array.

        The array is copied, so later changes to it do not affect the image.

        Args:
            array (numpy.NDArray): Numpy array.

        Returns:
            Image: Image instance of given numpy array.
        '''
        # enforce bit depth compliance (BitDepth.from_dtype presumably raises
        # for unsupported dtypes — TODO confirm exception type)
        BitDepth.from_dtype(array.dtype)
        return Image(array.copy(), {}, None, allow=True)

    @staticmethod
    def from_pil(image):
        # type: (pil.Image) -> Image
        '''
        Construct an Image instance from a given PIL Image.

        Args:
            image (pil.Image): PIL Image.

        Returns:
            Image: Image instance of a given PIL Image.
        '''
        return Image.from_array(np.array(image))

    @staticmethod
    def read(filepath):
        # type: (Filepath) -> Image
        '''
        Constructs an Image instance given a full path to an image file.

        EXR files are read with openexr_tools (which also supplies metadata);
        all other formats are read with PIL.

        Args:
            filepath (str or Path): Image filepath.

        Raises:
            FileNotFoundError: If file could not be found on disk.
            TypeError: If filepath is not a str or Path.

        Returns:
            Image: Image instance of given file.
        '''
        metadata = {}  # type: dict[str, Any]

        if isinstance(filepath, Path):
            filepath = filepath.absolute().as_posix()

        if not isinstance(filepath, str):
            msg = f'Object of type {filepath.__class__.__name__} '
            msg += 'is not a str or Path.'
            raise TypeError(msg)

        if not os.path.exists(filepath):
            msg = f'{filepath} does not exist.'
            raise FileNotFoundError(msg)

        _, ext = os.path.splitext(filepath)
        format_ = ImageFormat.from_extension(ext)

        if format_ is ImageFormat.EXR:
            data, metadata = exrtools.read_exr(filepath)
        else:
            # np.asarray loads the pixels, so the file handle can be released
            # immediately instead of waiting for garbage collection
            with pil.open(filepath) as img:
                data = np.asarray(img)

        return Image(data, metadata, format_, allow=True)

    def __init__(self, data, metadata=None, format_=None, allow=False):
        # type: (NDArray, Optional[dict[str, Any]], Optional[ImageFormat], bool) -> None
        '''
        This constructor should not be called directly except internally and in
        testing.

        Args:
            data (numpy.NDArray): Image data with 2 or 3 dimensions.
            metadata (dict, optional): Image metadata. Stored by reference, not
                copied. Default: None (a new empty dict).
            format_ (ImageFormat, optional): Format of image. Default: None.
            allow (bool, optional): Whether to allow construction using init.
                Default: False.

        Raises:
            NotImplementedError: If allow is False.
            AttributeError: If image data dimensions are not 2 or 3.
        '''
        if not allow:
            msg = "Please call one of Image's static constructors to create an "
            msg += 'instance. Options include: read, from_array.'
            raise NotImplementedError(msg)

        # ensure data has 3 dimensions
        shape = data.shape
        dims = len(shape)
        if dims > 3 or dims < 2:
            msg = f'Illegal number of dimensions for image data. {dims} not in '
            msg += '[2, 3].'
            raise AttributeError(msg)

        if dims == 2:
            # add a trailing channel axis: (height, width) -> (height, width, 1)
            data = data.reshape((*shape, 1))

        self._data = data
        # a fresh dict per instance; a shared mutable default would let one
        # instance's metadata changes leak into others
        self.metadata = {} if metadata is None else metadata
        self.format = format_
    # --------------------------------------------------------------------------

    def _repr(self):
        # type: () -> str
        '''
        Builds a multi-line, human-readable summary of the image.

        Returns:
            str: Width, height, channel count, bit depth and format name.
        '''
        fmat = str(None)
        if self.format is not None:
            fmat = self.format.name

        return f'''
       WIDTH: {self.width}
      HEIGHT: {self.height}
NUM_CHANNELS: {self.num_channels}
   BIT_DEPTH: {self.bit_depth.name}
      FORMAT: {fmat}'''[1:]

    def _repr_html_(self):
        # type: () -> None
        '''
        Creates a HTML representation of image data.

        Displays the image through ImageViewer and returns None.
        '''
        ImageViewer(self).show()

    def _repr_png(self):
        # type: () -> Optional[bytes]
        '''
        Creates a PNG representation of image data.

        Non-UINT8 images with values outside [0, 1] are first normalized with
        to_unit_space; the result is converted to UINT8 before encoding.

        Returns:
            bytes: PNG data, as returned by PIL's _repr_png_.
        '''
        this = self
        # UINT8 data is already in display range; its values are naturally
        # above 1.0, so the super-bright check only applies to other depths
        if self.bit_depth is not BitDepth.UINT8 and (
            _has_super_brights(self) or _has_super_darks(self)
        ):
            this = self.to_unit_space()
        return this.to_bit_depth(BitDepth.UINT8).to_pil()._repr_png_()

    def _string_to_channels(self, string):
        # type: (str) -> list
        '''
        Converts string to list of channels.

        Args:
            string (str): String representation of channels. Either rgba
                letters (e.g. 'rgb'), a layer name or a single channel name.

        Returns:
            list: List of channels.
        '''
        # special rgba short circuit: any non-empty combination of r, g, b, a
        chars = set(string)
        if chars and chars <= set('rgba'):
            return list(string)

        # if channels is actually a layer name
        if string in self.channel_layers:
            # anchored and escaped so layer 'spec' does not match 'xspec.r'
            pattern = re.escape(string) + r'\..+'
            found = [x for x in self.channels if re.match(pattern, str(x))]
            if found:
                # found channels that matched [layer-name].[channel] pattern
                return found
            return list(string)
        return [string]

    def __getitem__(self, indices):
        # type: (Union[int, tuple, list, slice, str]) -> Image
        '''
        Gets slice of image data. Indices are given in the order:
        width, height, channel.

        Channels may be given as integer positions, channel names, layer names
        or rgba strings such as 'rgb'. The names of the kept channels are
        stored in the result's metadata.

        Args:
            indices (int, tuple, list, slice, str): Slice of image data.

        Raises:
            IndexError: If number of indices provided is greater than 3.
            IndexError: If channel given is illegal.
            IndexError: If three lists are given as indices.

        Returns:
            Image: Image slice.
        '''
        if not isinstance(indices, tuple):
            indices = (indices,)

        size = len(indices)
        if size > 3:
            msg = f'Number of dimensions provided: {size}, is greater than 3.'
            raise IndexError(msg)

        # convert indices to triplet of slices
        columns = slice(None, None)  # type: Any
        rows = slice(None, None)  # type: Any
        channels = slice(None, None)  # type: Any
        if size == 3:
            columns, rows, channels = indices
        elif size == 2:
            columns, rows = indices
        else:
            columns = indices[0]

        # convert channels to integer positions and record kept channel names
        names = self.channels
        channel_meta = self.metadata.get('channels', [])
        if isinstance(channels, (str, tuple, list)):
            if isinstance(channels, str):
                channels = self._string_to_channels(channels)
            chans = []
            channel_meta = []
            for channel in channels:
                if isinstance(channel, str):
                    if channel not in names:
                        msg = f'{channel} is not a legal channel name.'
                        raise IndexError(msg)
                    channel = names.index(channel)
                chans.append(channel)
                # store the channel's name, not its raw integer position
                channel_meta.append(names[channel])

            # a single channel is indexed by integer; __init__ restores the
            # dropped channel axis
            channels = chans[0] if len(chans) == 1 else chans

        elif isinstance(channels, slice):
            channel_meta = names[channels]

        elif isinstance(channels, (int, np.integer)):
            # a lone integer keeps exactly one channel
            channel_meta = [names[channels]]

        # coerce to list for simpler logic
        if isinstance(columns, tuple):
            columns = list(columns)
        if isinstance(rows, tuple):
            rows = list(rows)

        # three lists would trigger numpy's pointwise fancy indexing
        if all(isinstance(x, list) for x in (columns, rows, channels)):
            msg = 'Three lists are not acceptable as indices.'
            raise IndexError(msg)

        data = self._data[rows, columns, channels]
        metadata = deepcopy(self.metadata)
        metadata['channels'] = channel_meta
        return Image(data, metadata=metadata, format_=self.format, allow=True)
    # --------------------------------------------------------------------------

    def set_channels(self, channels):
        # type: (list[Union[str, int]]) -> Image
        '''
        Sets channel names.

        Args:
            channels (list[str or int]): List of channel names. The list is
                not modified.

        Raises:
            ValueError: If number of channels given does not equal data
                shape.
            ValueError: If duplicate channel names found.

        Returns:
            Image: New Image instance with the given channel names.
        '''
        if len(channels) != self.num_channels:
            msg = 'Number of channels given does not equal last dimension size.'
            msg += f' {len(channels)} != {self.num_channels}.'
            raise ValueError(msg)

        uniq = set(channels)
        if len(uniq) < len(channels):
            # remove one occurrence of each name from a copy, leaving only the
            # extra occurrences, without mutating the caller's list
            dupes = list(channels)
            for c in uniq:
                dupes.remove(c)
            msg = f'Duplicate channel names found: {dupes}.'
            raise ValueError(msg)

        metadata = deepcopy(self.metadata)
        # copy so later changes to the caller's list do not affect the image
        metadata['channels'] = list(channels)
        return Image(
            self._data.copy(),
            metadata=metadata,
            format_=self.format,
            allow=True,
        )

    def write(self, filepath, codec=ImageCodec.PIZ):
        # type: (Filepath, ImageCodec) -> None
        '''
        Write image to file. The format is chosen from the file extension.

        Args:
            filepath (str or Path): Full path to image file.
            codec (ImageCodec, optional): EXR compression scheme to be used.
                Default: ImageCodec.PIZ.

        Raises:
            TypeError: If format does not support instance bit depth.
            AttributeError: If format does not support the number of channels in
                instance.
        '''
        if isinstance(filepath, Path):
            filepath = filepath.absolute().as_posix()

        _, ext_ = os.path.splitext(filepath)
        ext = ImageFormat.from_extension(ext_)

        # ensure format is compatible with image data
        if self.bit_depth not in ext.bit_depths:
            msg = f'{ext.name} cannot be written with {self.bit_depth.name}'
            msg += ' data.'
            raise TypeError(msg)

        if self.num_channels > ext.max_channels:
            msg = f'{ext.name} cannot be written with {self.num_channels} '
            msg += f'channels. Max channels supported: {ext.max_channels}.'
            raise AttributeError(msg)

        # write data
        if ext is ImageFormat.EXR:
            # copy so writing does not add keys to this instance's metadata
            metadata = deepcopy(self.metadata)
            metadata['channels'] = self.channels
            exrtools.write_exr(filepath, self._data, metadata, codec)

        else:
            data = self._data
            if data.shape[2] == 1:
                # PIL cannot build an image from an (H, W, 1) array
                data = data[:, :, 0]
            pil.fromarray(data).save(filepath, format=ext.name)

    def to_bit_depth(self, bit_depth):
        # type: (BitDepth) -> Image
        '''
        Convert image to given bit depth.
        Warning: Numpy's conversions for INT8 are bizarre.

        UINT8 -> float divides by 255; float -> 8-bit multiplies by 255 and
        truncates; all other conversions are a plain numpy astype.

        Args:
            bit_depth (BitDepth): Target bit depth.

        Raises:
            ValueError: If converting from float to 8-bit and values exceed 1.
            ValueError: If converting from float to 8-bit and values less than 0.

        Returns:
            Image: New Image instance at given bit depth, or self if the image
                is already at that bit depth.
        '''
        image = self._data
        src = self.bit_depth
        tgt = bit_depth

        if src is tgt:
            return self

        if src is BitDepth.UINT8 and tgt.type_ is float:
            image = image.astype(tgt.dtype) / 255

        elif src.type_ is float and tgt.bits == 8:
            if _has_super_darks(self):
                msg = f'Image has values below 0. Min value: {image.min()}'
                raise ValueError(msg)

            if _has_super_brights(self):
                msg = f'Image has values above 1. Max value: {image.max()}'
                raise ValueError(msg)

            image = (image * 255).astype(tgt.dtype)

        else:
            image = image.astype(tgt.dtype)

        metadata = deepcopy(self.metadata)
        return Image(image, metadata=metadata, format_=self.format, allow=True)

    def to_unit_space(self):
        # type: () -> Image
        '''
        Normalizes image to [0, 1] range, keeping the original bit depth.

        A constant image (max == min) is mapped to all zeros.

        Returns:
            Image: Normalized image.
        '''
        data = self.to_bit_depth(BitDepth.FLOAT32)._data.copy()
        max_, min_ = data.max(), data.min()
        range_ = max_ - min_
        if range_ == 0:
            # avoid 0 / 0 producing NaN for constant images
            range_ = 1
        data = (data - min_) / range_
        data = Image.from_array(data).to_bit_depth(self.bit_depth)._data
        metadata = deepcopy(self.metadata)
        return Image(data, metadata=metadata, format_=self.format, allow=True)

    def to_array(self):
        # type: () -> NDArray
        '''
        Returns numpy array (not a copy).

        Returns:
            numpy.NDArray: Image as numpy array.
        '''
        return self.data

    def to_pil(self):
        # type: () -> pil.Image
        '''
        Returns pil.Image.

        Raises:
            ValueError: If image does not have 1, 3 or 4 channels.

        Returns:
            pil: Image as pil.Image.
        '''
        if self.num_channels == 1:
            mode = 'L'
        elif self.num_channels == 3:
            mode = 'RGB'
        elif self.num_channels == 4:
            mode = 'RGBA'
        else:
            raise ValueError('PIL only accepts image with 1, 3 or 4 channels.')
        return pil.fromarray(self.data, mode=mode)

    def compare(self, image, content=False, diff_only=False):
        # type: (Image, bool, bool) -> dict[str, Any]
        '''
        Compare this image with a given image.

        Each info key maps to a (this value, other value) tuple. With content
        True, 'mean_content_difference' holds the mean absolute difference of
        both images converted to FLOAT16.

        Args:
            image (Image): Image to compare.
            content (bool, optional): If True, compare data. Default: False.
            diff_only (bool, optional): If True, only return the keys with
                differing values. Default: False.

        Raises:
            EnforceError: If image is not an Image instance.
            ValueError: IF content is True and images cannot be compared.

        Returns:
            dict: A dictionary of comparisons.
        '''
        msg = 'Image must be an instance of Image.'
        Enforce(image, 'instance of', Image, message=msg)
        # ----------------------------------------------------------------------

        a = self.info
        b = image.info
        # keys of this image first, then keys only the other image has
        output = {k: (v, b.get(k, None)) for k, v in a.items()}  # type: dict[str, Any]
        for k, v in b.items():
            if k not in output:
                output[k] = (a.get(k, None), v)

        if diff_only:
            output = {k: v for k, v in output.items() if v[0] != v[1]}

        if content:
            x = self.to_bit_depth(BitDepth.FLOAT16).data
            y = image.to_bit_depth(BitDepth.FLOAT16).data

            try:
                diff = float(abs(x - y).mean())
            except ValueError as e:
                # typically a numpy broadcasting failure between shapes
                raise ValueError(f'Cannot compare images: {e}') from e

            if not (diff_only and diff == 0):
                output['mean_content_difference'] = diff

        return output

    def __eq__(self, image):
        # type: (object) -> bool
        '''
        Compare this image with a given image.

        Raises:
            EnforceError: If image is not an Image instance.

        Returns:
            bool: True if images are equal in info and content.
        '''
        # differing info (e.g. shape) already means inequality, and checking it
        # first avoids a content comparison of incompatible array shapes
        if self.compare(image, diff_only=True) != {}:  # type: ignore
            return False
        return self.compare(image, content=True, diff_only=True) == {}  # type: ignore
    # --------------------------------------------------------------------------

    @property
    def data(self):
        # type: () -> NDArray
        '''
        numpy.NDArray: Image data. Single channel images are returned as 2D
            (height, width) arrays.
        '''
        if self.num_channels == 1:
            return np.squeeze(self._data, axis=2)
        return self._data

    @property
    def info(self):
        # type: () -> dict[str, Any]
        '''
        dict: A dictionary of all information about the Image instance.
            format_* keys are None when the image has no format.
        '''
        output = dict(
            width=self.width,
            height=self.height,
            channels=self.channels,
            num_channels=self.num_channels,
            bit_depth=self.bit_depth.name,
            dtype=self.bit_depth.dtype,
            bits=self.bit_depth.bits,
            signed=self.bit_depth.signed,
            type=self.bit_depth.type_,
            format_extension=None,
            format_bit_depths=None,
            format_channels=None,
            format_max_channels=None,
            format_custom_metadata=None,
        )
        if self.format is not None:
            output.update(
                format_extension=self.extension,
                format_bit_depths=self.format.bit_depths,
                format_channels=self.format.channels,
                format_max_channels=self.format.max_channels,
                format_custom_metadata=self.format.custom_metadata,
            )
        return output

    @property
    def shape(self):
        # type: () -> Tuple[int, int, int]
        '''
        tuple[int]: (width, height, channels) of image.
        '''
        return (self.width, self.height, self.num_channels)

    @property
    def width(self):
        # type: () -> int
        '''
        int: Width of image.
        '''
        return self._data.shape[1]

    @property
    def height(self):
        # type: () -> int
        '''
        int: Height of image.
        '''
        return self._data.shape[0]

    @property
    def width_and_height(self):
        # type: () -> Tuple[int, int]
        '''
        tuple[int]: (width, height) of image.
        '''
        return (self.width, self.height)

    @property
    def channels(self):
        # type: () -> list[Union[str, int]]
        '''
        list[str or int]: List of channel names. Taken from metadata when
            present, otherwise derived from the array.
        '''
        if 'channels' in self.metadata:
            return self.metadata['channels']
        return cvt.get_channels_from_array(self._data)

    @property
    def num_channels(self):
        # type: () -> int
        '''
        int: Number of channels in image.
        '''
        return len(self.channels)

    @property
    def max_channels(self):
        # type: () -> Optional[int]
        '''
        int: Maximum number of channels supported by image format, or None if
            the image has no format.
        '''
        if self.format is None:
            return None
        return self.format.max_channels

    @property
    def channel_layers(self):
        # type: () -> list[str]
        '''
        list[str]: List of channel layers. Channels without a layer prefix are
            grouped in fours (e.g. 'rgba'), followed by each unique prefix of
            '[layer].[channel]' names in order of appearance.
        '''
        channels = [str(x) for x in self.channels]
        with_layer = [x for x in channels if '.' in x]  # type: list[str]
        without_layer = [x for x in channels if '.' not in x]

        # break out channels without layer names into groups of 4
        layers = [
            ''.join(without_layer[idx:idx + 4])
            for idx in range(0, len(without_layer), 4)
        ]

        # append all unique layers of channels with layer names
        for chan in with_layer:
            layer = chan.split('.')[0]
            if layer not in layers:
                layers.append(layer)
        return layers

    @property
    def bit_depth(self):
        # type: () -> BitDepth
        '''
        BitDepth: Bit depth of image, derived from the array dtype.
        '''
        return BitDepth.from_dtype(self._data.dtype)

    @property
    def extension(self):
        # type: () -> Optional[str]
        '''
        str: Image format extension, or None if the image has no format.
        '''
        if self.format is None:
            return None
        return self.format.extension