Coverage for /home/ubuntu/openexr-tools/python/openexr_tools/tools.py: 100%
96 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 15:56 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-08 15:56 +0000
1from numpy.typing import DTypeLike, NDArray # noqa F401
2from typing import Tuple, Union # noqa F401
4from copy import deepcopy
5from pathlib import Path
7import Imath as imath
8import numpy as np
9import OpenEXR as openexr
11from openexr_tools.enum import ImageCodec
12# ------------------------------------------------------------------------------
def read_exr(fullpath):
    # type: (Union[str, Path]) -> Tuple[NDArray, dict]
    '''
    Reads an OpenEXR image file.

    Channels are ordered R, G, B, A first (whichever of those are present),
    followed by all remaining channels in sorted order.

    Args:
        fullpath (str or Path): Image file path.

    Raises:
        IOError: If given filepath is not an EXR file.

    Returns:
        tuple[numpy.NDArray, dict]: Image of shape (height, width, channels)
            and metadata. In the metadata, "channels" is replaced by the list
            of lowercased channel names, "num_channels" is added,
            "compression" is converted to an ImageCodec and any bytes values
            are decoded as UTF-8.
    '''
    if isinstance(fullpath, Path):
        fullpath = fullpath.absolute().as_posix()

    if not openexr.isOpenExrFile(fullpath):
        msg = f'{fullpath} is not an EXR file.'
        raise IOError(msg)

    # EXR pixel type name -> numpy dtype. FLOAT is float32, HALF is float16,
    # UINT is a 32-bit unsigned int (previously mis-decoded as float32).
    # Unknown names fall back to float32, as before.
    pixel_dtypes = {
        'HALF': np.float16,
        'FLOAT': np.float32,
        'UINT': np.uint32,
    }

    img = openexr.InputFile(fullpath)
    try:
        metadata = img.header()
        win = metadata['dataWindow']
        width = (win.max.x - win.min.x) + 1
        height = (win.max.y - win.min.y) + 1

        # EXR headers store channel data in a map, so there can be no support
        # for arbitrary channel order persistence.
        remaining = sorted(metadata['channels'].keys())
        rgba = [chan for chan in 'RGBA' if chan in remaining]
        channels = rgba + [chan for chan in remaining if chan not in rgba]

        image_stack = []
        for chan in channels:
            data = metadata['channels'][chan]
            buffer = img.channel(chan, data.type)
            dtype = pixel_dtypes.get(str(data.type), np.float32)  # type: DTypeLike
            image_stack.append(
                np.frombuffer(buffer, dtype).reshape((height, width))
            )
    finally:
        # release the underlying file handle even if decoding fails
        img.close()

    image = np.dstack(image_stack)  # type: np.ndarray
    metadata['channels'] = [chan.lower() for chan in channels]
    metadata['num_channels'] = len(channels)

    # convert to compression enum
    comp = metadata['compression']
    metadata['compression'] = ImageCodec.from_exr_code(comp.v)

    # only values are replaced here (no keys added/removed), so mutating the
    # dict while iterating its items is safe
    for key, val in metadata.items():
        if isinstance(val, bytes):
            metadata[key] = val.decode('utf-8')

    return image, metadata
def clean_exr_metadadata(image, metadata):
    # type: (NDArray, dict) -> dict
    '''
    Uses given image and metadata dictionary to create EXR metadata for use in
    writing EXRs.

    The channel name list is made exactly as long as the image's channel
    dimension. Extra names are dropped, and missing names are filled with
    "aux_0000", "aux_0001", ... . A single unnamed channel is named "l".
    Keys that the EXR writer sets itself (compression, windows, line order,
    etc.) are removed. The input dictionary is not modified.

    Args:
        image (numpy.NDArray): Image.
        metadata (dict): Metadata dictionary.

    Returns:
        dict: Clean metadata.
    '''
    metadata = deepcopy(metadata)

    # ensure length of channels is the same length as image's channel dimension
    num_channels = 1
    if len(image.shape) > 2:
        num_channels = image.shape[2]

    # copy into a list so tuples (or a missing/None entry) are accepted, and
    # drop names that have no corresponding image channel
    channels = list(metadata.get('channels') or [])[:num_channels]

    # do not assume rgba channel names for unnamed channels
    delta = num_channels - len(channels)
    channels.extend(f'aux_{i:04d}' for i in range(delta))

    # use l channel name for grayscale images
    if channels == ['aux_0000']:
        channels = ['l']

    metadata['channels'] = channels

    # remove forbidden keys
    forbidden = (
        'compression',
        'dataWindow',
        'displayWindow',
        'lineOrder',
        'pixelAspectRatio',
        'screenWindowCenter',
        'screenWindowWidth',
    )
    for key in forbidden:
        metadata.pop(key, None)

    return metadata
def write_exr(fullpath, image, metadata, codec=ImageCodec.PIZ):
    # type: (Union[str, Path], NDArray, dict, ImageCodec) -> None
    '''
    Writes image data and metadata as EXR to given file path.

    Metadata is first passed through clean_exr_metadadata. String values are
    encoded as UTF-8 bytes before being stored in the header. Channel names
    "l", "r", "g", "b" and "a" are written uppercased.

    Args:
        fullpath (str or Path): Path to EXR file.
        image (numpy.NDArray): Image data.
        metadata (dict): Dictionary of EXR metadata.
        codec (ImageCodec, optional): Image codec. Default: ImageCodec.PIZ.

    Raises:
        TypeError: If image is not float16 or float32.
    '''
    dtype = image.dtype
    if dtype not in [np.float16, np.float32]:
        msg = f'EXR cannot be saved with array of dtype: {dtype}.'
        raise TypeError(msg)

    # determine bit depth of EXR
    pixel_type = imath.PixelType.FLOAT
    if dtype == np.float16:
        pixel_type = imath.PixelType.HALF
    ctype = imath.Channel(imath.PixelType(pixel_type))

    # ensure metadata is clean
    metadata = clean_exr_metadadata(image, metadata)

    # ensure image has a channel axis
    if len(image.shape) < 3:
        image = image.reshape(list(image.shape) + [1])

    # create EXR data and channels objects; never index past the image's
    # channel axis, even if more names than channels were supplied
    names = metadata['channels'][:image.shape[2]]
    channels = {}
    data = {}
    for i, chan in enumerate(names):
        chan = str(chan)
        if chan in ('l', 'r', 'g', 'b', 'a'):
            chan = chan.upper()
        data[chan] = image[:, :, i].tobytes()
        channels[chan] = ctype

    # create EXR header
    y, x = image.shape[:2]
    header = openexr.Header(x, y)

    # all strings must be bytes
    for key, val in metadata.items():
        if isinstance(val, str):
            val = val.encode('utf-8')
        header[key] = val

    header['channels'] = channels
    header['compression'] = imath.Compression(codec.exr_code)

    # write EXR data
    if isinstance(fullpath, Path):
        fullpath = fullpath.absolute().as_posix()

    output = openexr.OutputFile(fullpath, header)
    try:
        output.writePixels(data)
    finally:
        # explicitly finalize the file instead of relying on garbage collection
        output.close()