Coverage for /home/ubuntu/cv-depot/python/cv_depot/core/enum.py: 100%

126 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-08 20:26 +0000

1from typing import Any # noqa F401 

2 

3from enum import Enum 

4import re 

5 

6from lunchbox.enforce import Enforce 

7import numpy as np 

8# ------------------------------------------------------------------------------ 

9 

10 

11''' 

12The enum module contains Enum classes for managing aspects of imagery such as bit 

13depths and video codecs. 

14''' 

15 

16 

class EnumBase(Enum):
    '''
    Base class for enums that can be looked up by member name from a string.
    '''
    def __repr__(self):
        # type: () -> str
        '''
        str: String representation of enum, eg. "ClassName.MEMBER".
        '''
        return f'{self.__class__.__name__}.{self.name.upper()}'

    @classmethod
    def from_string(cls, string):
        # type: (str) -> EnumBase
        '''
        Constructs an enum instance from a given string.

        The lookup is case-insensitive and hyphens are treated as underscores,
        so "top-left" resolves to the member named TOP_LEFT.

        Args:
            string (str): Enum string.

        Raises:
            EnforceError: If value given is not a string.
            EnforceError: If no EnumBase type can be found for given string.

        Returns:
            EnumBase: Enum instance.
        '''
        msg = 'Value given is not a string. {a} != {b}.'
        Enforce(string, 'instance of', str, message=msg)

        # keyed by each member's canonical name (alias names are not included)
        lut = {x.name: x for x in cls.__members__.values()}
        string = string.upper().replace('-', '_')

        msg = '{a} is not a ' + cls.__name__ + ' option. '
        msg += f'Options: {sorted(lut.keys())}.'
        Enforce(string, 'in', lut.keys(), message=msg)

        return lut[string]

52 

53 

54# BITDEPTH---------------------------------------------------------------------- 

class BitDepth(EnumBase):
    '''
    Legal bit depths.

    Includes:

        * FLOAT16
        * FLOAT32
        * UINT8
        * INT8
    '''
    FLOAT16 = (np.float16, 16, True, float)
    FLOAT32 = (np.float32, 32, True, float)
    UINT8 = (np.uint8, 8, False, int)
    INT8 = (np.int8, 8, True, int)

    # NOTE: repr comes from EnumBase and yields eg. "BitDepth.FLOAT16".

    def __init__(self, dtype, bits, signed, type_):
        # type: (Any, int, bool, type) -> None
        '''
        Store the per-member attributes unpacked from the member's value tuple.

        Args:
            dtype (numpy.type): Numpy datatype.
            bits (int): Number of bits per channel.
            signed (bool): Whether channel scalars are signed.
            type_ (type): Python type of scalar. Options include: [int, float].
        '''
        self.dtype = dtype  # type: ignore
        self.bits = bits
        self.signed = signed
        self.type_ = type_

    @staticmethod
    def from_dtype(dtype):
        # type: (Any) -> BitDepth
        '''
        Construct a BitDepth instance from a given numpy datatype.

        Matching uses ``==`` against each member's dtype in definition order,
        so anything that compares equal to a member's dtype (such as
        ``np.dtype('float16')``) is accepted.

        Args:
            dtype (numpy.type): Numpy datatype. Options include:
                [float16, float32, uint8, int8].

        Raises:
            TypeError: If invalid dtype is given.

        Returns:
            BitDepth: BitDepth instance of given type.
        '''
        for member in BitDepth:
            if dtype == member.dtype:
                return member

        # needed because of numpy malarkey with __name__
        if hasattr(dtype, '__name__'):
            dtype = dtype.__name__
        msg = f'{dtype} is not a supported bit depth.'
        raise TypeError(msg)

122 

123 

124# IMAGE------------------------------------------------------------------------- 

class ImageFormat(Enum):
    '''
    Legal image formats.

    Includes:

        * EXR
        * PNG
        * JPEG
        * TIFF
    '''
    EXR = (
        'exr', [BitDepth.FLOAT16, BitDepth.FLOAT32], list('rgba') + ['...'],
        1023, True
    )
    PNG = ('png', [BitDepth.UINT8], list('rgba'), 4, False)
    JPEG = ('jpeg', [BitDepth.UINT8], list('rgb'), 3, False)
    TIFF = (
        'tiff', [BitDepth.INT8, BitDepth.UINT8, BitDepth.FLOAT32],
        list('rgba') + ['...'], 500, False
    )

    def __init__(self, extension, bit_depths, channels, max_channels,
                 custom_metadata):
        # type: (str, list[BitDepth], list[str], int, bool) -> None
        '''
        Store the per-member attributes unpacked from the member's value tuple.

        Args:
            extension (str): Name of file extension.
            bit_depths (list[BitDepth]): Supported bit depths.
            channels (list[str]): Supported channels.
            max_channels (int): Maximum number of channels supported.
            custom_metadata (bool): Custom metadata support.
        '''
        self.extension = extension
        self.bit_depths = bit_depths
        self.channels = channels
        self.max_channels = max_channels
        self.custom_metadata = custom_metadata

    def __repr__(self):
        # type: () -> str
        return f'''
<ImageFormat.{self.name.upper()}>
      extension: {self.extension}
     bit_depths: {[x.name for x in self.bit_depths]}
       channels: {self.channels}
   max_channels: {self.max_channels}
custom_metadata: {self.custom_metadata}'''[1:]

    @staticmethod
    def from_extension(extension):
        # type: (str) -> ImageFormat
        '''
        Construct an ImageFormat instance for a given file extension.

        Matching is case-insensitive, a single leading dot is optional, and the
        whole string must match (eg. "jpg", ".JPEG", "tif" are accepted).

        Args:
            extension (str): File extension.

        Raises:
            TypeError: If extension is illegal.

        Returns:
            ImageFormat: ImageFormat instance of given extension.
        '''
        # fullmatch, unlike search with ^...$, rejects a trailing newline
        lut = [
            (r'\.?exr', ImageFormat.EXR),
            (r'\.?png', ImageFormat.PNG),
            (r'\.?jpe?g', ImageFormat.JPEG),
            (r'\.?tiff?', ImageFormat.TIFF),
        ]
        for pattern, image_format in lut:
            if re.fullmatch(pattern, extension, re.I):
                return image_format

        msg = f'ImageFormat not found for given extension: {extension}'
        raise TypeError(msg)

210 

211 

212# VIDEO------------------------------------------------------------------------- 

class VideoFormat(Enum):
    '''
    Legal video formats.

    Includes:

        * MP4
        * MPEG
        * MOV
        * M4V
    '''
    MP4 = ('mp4', [BitDepth.UINT8], list('rgb'), 3, False)
    MPEG = ('mpeg', [BitDepth.UINT8], list('rgb'), 3, False)
    MOV = ('mov', [BitDepth.UINT8], list('rgb'), 3, False)
    M4V = ('m4v', [BitDepth.UINT8], list('rgb'), 3, False)

    def __init__(self, extension, bit_depths, channels, max_channels,
                 custom_metadata):
        # type: (str, list[BitDepth], list[str], int, bool) -> None
        '''
        Store the per-member attributes unpacked from the member's value tuple.

        Args:
            extension (str): Name of file extension.
            bit_depths (list[BitDepth]): Supported bit depths.
            channels (list[str]): Supported channels.
            max_channels (int): Maximum number of channels supported.
            custom_metadata (bool): Custom metadata support.
        '''
        self.extension = extension
        self.bit_depths = bit_depths
        self.channels = channels
        self.max_channels = max_channels
        self.custom_metadata = custom_metadata

    def __repr__(self):
        # type: () -> str
        return f'''
<VideoFormat.{self.name.upper()}>
      extension: {self.extension}
     bit_depths: {[x.name for x in self.bit_depths]}
       channels: {self.channels}
   max_channels: {self.max_channels}
custom_metadata: {self.custom_metadata}'''[1:]

    @staticmethod
    def from_extension(extension):
        # type: (str) -> VideoFormat
        '''
        Construct a VideoFormat instance for a given file extension.

        Matching is case-insensitive, a single leading dot is optional, and the
        whole string must match (eg. "mp4", ".MPG", "mov" are accepted).

        Args:
            extension (str): File extension.

        Raises:
            TypeError: If extension is invalid.

        Returns:
            VideoFormat: VideoFormat instance of given extension.
        '''
        # fullmatch, unlike search with ^...$, rejects a trailing newline
        lut = [
            (r'\.?mp4', VideoFormat.MP4),
            (r'\.?mpe?g', VideoFormat.MPEG),
            (r'\.?mov', VideoFormat.MOV),
            (r'\.?m4v', VideoFormat.M4V),
        ]
        for pattern, video_format in lut:
            if re.fullmatch(pattern, extension, re.I):
                return video_format

        msg = f'VideoFormat not found for given extension: {extension}'
        raise TypeError(msg)

292# ------------------------------------------------------------------------------ 

293 

294 

class VideoCodec(Enum):
    '''
    Legal video codecs.

    Includes:

        * H264
        * H265
    '''
    # value layout: (string, ffmpeg_code)
    H264 = ('h264', 'h264')
    H265 = ('h265', 'hevc')

    def __init__(self, string, ffmpeg_code):
        # type: (str, str) -> None
        '''
        Unpack a member's value tuple into named attributes.

        Args:
            string (str): String representation of codec.
            ffmpeg_code (str): FFMPEG code.
        '''
        self.string, self.ffmpeg_code = string, ffmpeg_code

    def __repr__(self):
        # type: () -> str
        rows = [
            f'<VideoCodec.{self.name.upper()}>',
            f'         string: {self.string}',
            f'    ffmpeg_code: {self.ffmpeg_code}',
        ]
        return '\n'.join(rows)

323 

324 

325# DIRECTION--------------------------------------------------------------------- 

class Direction(EnumBase):
    '''
    Legal directions.

    Includes:

        * TOP
        * BOTTOM
        * LEFT
        * RIGHT
    '''
    # values are plain lowercase strings, not 1-tuples
    TOP = 'top'
    BOTTOM = 'bottom'
    LEFT = 'left'
    RIGHT = 'right'

341 

342 

class Anchor(EnumBase):
    '''
    Legal anchors. Each value is a (vertical, horizontal) pair of positions.

    Includes:

        * TOP_LEFT
        * TOP_CENTER
        * TOP_RIGHT
        * CENTER_LEFT
        * CENTER_CENTER
        * CENTER_RIGHT
        * BOTTOM_LEFT
        * BOTTOM_CENTER
        * BOTTOM_RIGHT
    '''
    # top row
    TOP_LEFT = ('top', 'left')
    TOP_CENTER = ('top', 'center')
    TOP_RIGHT = ('top', 'right')

    # middle row
    CENTER_LEFT = ('center', 'left')
    CENTER_CENTER = ('center', 'center')
    CENTER_RIGHT = ('center', 'right')

    # bottom row
    BOTTOM_LEFT = ('bottom', 'left')
    BOTTOM_CENTER = ('bottom', 'center')
    BOTTOM_RIGHT = ('bottom', 'right')