Coverage for /home/ubuntu/flatiron/python/flatiron/core/dataset.py: 100%

243 statements  

coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

from typing import Any, Optional, Union  # noqa F401
from flatiron.core.types import Filepath, OptArray, OptInt, OptLabels  # noqa F401

from pathlib import Path
import os
import random
import re

from lunchbox.enforce import Enforce
from tqdm.notebook import tqdm
import cv_depot.api as cvd
import humanfriendly as hf
import numpy as np
import pandas as pd

import flatiron.core.tools as fict
# ------------------------------------------------------------------------------


class Dataset:
    @classmethod
    def read_csv(cls, filepath, **kwargs):
        # type: (Filepath, Any) -> Dataset
        '''
        Construct Dataset instance from given csv filepath.

        Args:
            filepath (str or Path): Info CSV filepath.

        Raises:
            EnforceError: If filepath does not exist or is not a CSV.

        Returns:
            Dataset: Dataset instance.
        '''
        fp = Path(filepath)
        msg = f'Filepath does not exist: {fp}'
        Enforce(fp.is_file(), '==', True, message=msg)

        msg = f'Filepath extension must be csv. Given filepath: {fp}'
        Enforce(fp.suffix.lower()[1:], '==', 'csv', message=msg)
        # ----------------------------------------------------------------------

        info = pd.read_csv(filepath)
        return cls(info, **kwargs)

    @classmethod
    def read_directory(cls, directory, **kwargs):
        # type: (Filepath, Any) -> Dataset
        '''
        Construct dataset from directory.

        Args:
            directory (str or Path): Dataset directory.

        Raises:
            EnforceError: If directory does not exist.
            EnforceError: If more or less than 1 CSV file found in directory.

        Returns:
            Dataset: Dataset instance.
        '''
        msg = f'Directory does not exist: {directory}'
        Enforce(Path(directory).is_dir(), '==', True, message=msg)

        files = [Path(directory, x) for x in os.listdir(directory)]  # type: Any
        files = list(filter(lambda x: x.suffix.lower()[1:] == 'csv', files))
        files = sorted([x.as_posix() for x in files])
        msg = 'Dataset directory must contain only 1 CSV file. '
        msg += f'CSV files found: {files}'
        Enforce(len(files), '==', 1, message=msg)
        # ----------------------------------------------------------------------

        return cls.read_csv(files[0], **kwargs)
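
    # Construction sketch (hypothetical paths, not part of the original module):
    #
    #     dataset = Dataset.read_directory('/data/my_asset')
    #     # equivalent, given the single info CSV inside that directory:
    #     dataset = Dataset.read_csv('/data/my_asset/info.csv')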

    def __init__(
        self, info, ext_regex='npy|exr|png|jpeg|jpg|tiff', calc_file_size=True,
        labels=None, label_axis=-1
    ):
        # type: (pd.DataFrame, str, bool, OptLabels, int) -> None
        '''
        Construct a Dataset instance.
        If labels is an integer, it will be assumed to be an axis upon which
        the data will be split.

        Args:
            info (pd.DataFrame): Info DataFrame.
            ext_regex (str, optional): File extension pattern.
                Default: 'npy|exr|png|jpeg|jpg|tiff'.
            calc_file_size (bool, optional): Calculate file size in GB.
                Default: True.
            labels (object, optional): Label channels. Default: None.
            label_axis (int, optional): Label axis. Default: -1.

        Raises:
            EnforceError: If info is not an instance of DataFrame.
            EnforceError: If required columns not found in info.
        '''
        Enforce(info, 'instance of', pd.DataFrame)

        # columns
        columns = info.columns.tolist()
        required = ['asset_path', 'filepath_relative']
        diff = sorted(list(set(required).difference(columns)))
        msg = f'Required columns not found in info: {diff}'
        Enforce(len(diff), '==', 0, message=msg)

        # root path
        root = info.asset_path.unique().tolist()
        msg = f'Info must contain only 1 root path. Paths found: {root}'
        Enforce(len(root), '==', 1, message=msg)
        root = root[0]
        msg = f'Directory does not exist: {root}'
        Enforce(Path(root).is_dir(), '==', True, message=msg)

        # info files
        info['filepath'] = info.filepath_relative \
            .apply(lambda x: Path(root, x).as_posix())
        mask = info.filepath.apply(lambda x: not Path(x).is_file())
        absent = info.loc[mask, 'filepath'].tolist()
        msg = f'Files do not exist: {absent}'
        Enforce(len(absent), '==', 0, message=msg)

        # extension
        mask = info.filepath \
            .apply(lambda x: Path(x).suffix.lower()[1:]) \
            .apply(lambda x: re.search(ext_regex, x, re.I) is None)
        bad_ext = sorted(info.loc[mask, 'filepath'].tolist())
        msg = f'Found file extensions that do not match ext_regex: {bad_ext}'
        Enforce(len(bad_ext), '==', 0, message=msg)

        # frame indicators
        frame_regex = r'_(f|c)(\d+)\.' + f'({ext_regex})$'
        mask = info.filepath.apply(lambda x: re.search(frame_regex, x) is None)
        bad_frames = info.loc[mask, 'filepath'].tolist()
        msg = 'Found files missing frame indicators. '
        msg += f"File names must match '{frame_regex}'. "
        msg += f'Invalid frames: {bad_frames}'
        Enforce(len(bad_frames), '==', 0, message=msg)

        # frame column
        info['frame'] = info.filepath \
            .apply(lambda x: re.search(frame_regex, x).group(2)).astype(int)  # type: ignore

        # gb column
        info['gb'] = np.nan
        if calc_file_size:
            info['gb'] = info.filepath \
                .apply(lambda x: os.stat(x).st_size / 10**9)

        # loaded column
        info['loaded'] = False
        # ----------------------------------------------------------------------

        # reorganize columns
        cols = [
            'gb', 'frame', 'asset_path', 'filepath_relative', 'filepath',
            'loaded'
        ]
        cols = cols + info.drop(cols, axis=1).columns.tolist()
        info = info[cols]

        self._info = info  # type: pd.DataFrame
        self.data = None  # type: OptArray
        self.labels = labels
        self.label_axis = label_axis
        self._ext_regex = ext_regex
        self._calc_file_size = calc_file_size
        self._sample_gb = np.nan  # type: Union[float, np.ndarray]
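
    # A minimal info DataFrame accepted by the constructor might look like the
    # following (hypothetical values, shown for illustration only):
    #
    #     info = pd.DataFrame([
    #         dict(asset_path='/data/my_asset', filepath_relative='my_asset_f0001.npy'),
    #         dict(asset_path='/data/my_asset', filepath_relative='my_asset_f0002.npy'),
    #     ])
    #     dataset = Dataset(info, labels=[3])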

    @property
    def info(self):
        # type: () -> pd.DataFrame
        '''
        Returns:
            DataFrame: Copy of info DataFrame.
        '''
        return self._info.copy()

    @property
    def filepaths(self):
        # type: () -> list[str]
        '''
        Returns:
            list[str]: Filepaths sorted by frame.
        '''
        return self._info.sort_values('frame').filepath.tolist()

    @property
    def asset_path(self):
        # type: () -> str
        '''
        Returns:
            str: Asset path of Dataset.
        '''
        return self.info.loc[0, 'asset_path']

    @property
    def asset_name(self):
        # type: () -> str
        '''
        Returns:
            str: Asset name of Dataset.
        '''
        return Path(self.asset_path).name

    @property
    def stats(self):
        # type: () -> pd.DataFrame
        '''
        Generates a table of statistics of info data.

        Metrics include:

        * min
        * max
        * mean
        * std
        * loaded
        * total

        Units include:

        * gb
        * frame
        * sample

        Returns:
            DataFrame: Table of statistics.
        '''
        info = self.info
        a = self._get_stats(info)
        b = self._get_stats(info.loc[info.loaded]) \
            .loc[['total']].rename(lambda x: 'loaded')
        stats = pd.concat([a, b])
        stats['sample'] = np.nan

        if self.data is not None:
            loaded = round(self.data.nbytes / 10**9, 2)
            stats.loc['loaded', 'gb'] = loaded

            # sample stats
            total = info['gb'].sum() / self._sample_gb
            stats.loc['loaded', 'sample'] = self.data.shape[0]
            stats.loc['total', 'sample'] = total
            stats.loc['mean', 'sample'] = total / info.shape[0]
            stats['sample'] = stats['sample'].apply(lambda x: round(x, 0))

        index = ['min', 'max', 'mean', 'std', 'loaded', 'total']
        stats = stats.loc[index]
        return stats

    @staticmethod
    def _get_stats(info):
        # type: (pd.DataFrame) -> pd.DataFrame
        '''
        Creates table of statistics from given info DataFrame.

        Args:
            info (pd.DataFrame): Info DataFrame.

        Returns:
            pd.DataFrame: Stats DataFrame.
        '''
        stats = info.describe()
        rows = ['min', 'max', 'mean', 'std', 'count']
        stats = stats.loc[rows]
        stats.loc['total'] = info[stats.columns].sum()
        stats.loc['total', 'frame'] = stats.loc['count', 'frame']
        stats.loc['mean', 'frame'] = np.nan
        stats.loc['std', 'frame'] = np.nan
        stats = stats.map(lambda x: round(x, 2))
        stats.drop('count', inplace=True)
        return stats

    def __repr__(self):
        # type: () -> str
        '''
        Returns:
            str: Info statistics.
        '''
        msg = f'''
        <Dataset>
            ASSET_NAME: {self.asset_name}
            ASSET_PATH: {self.asset_path}
            STATS:
        '''[1:]
        msg = fict.unindent(msg, spaces=8)
        cols = ['gb', 'frame', 'sample']
        stats = str(self.stats[cols])
        stats = '\n    '.join(stats.split('\n'))
        msg = msg + stats
        return msg

    def __len__(self):
        # type: () -> int
        '''
        Returns:
            int: Number of frames.
        '''
        return len(self._info)

    def __getitem(self, frame):
        # type: (int) -> Any
        '''
        Get data by frame.
        This is needed to avoid recursion errors when overloading __getitem__.

        Args:
            frame (int): Frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            object: Data of given frame.
        '''
        return self._read_file(self.get_filepath(frame))

    def __getitem__(self, frame):
        # type: (int) -> Any
        '''
        Get data by frame.

        Args:
            frame (int): Frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            object: Data of given frame.
        '''
        return self.__getitem(frame)
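
    # Indexing sketch (hypothetical frame number):
    #
    #     item = dataset[1]  # np.ndarray or cv_depot Image for frame 1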

    def get_filepath(self, frame):
        # type: (int) -> Any
        '''
        Get filepath of given frame.

        Args:
            frame (int): Frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            str: Filepath of given frame.
        '''
        info = self._info
        mask = info.frame == frame
        filepaths = info.loc[mask, 'filepath'].tolist()
        if len(filepaths) == 0:
            raise IndexError(f'Missing frame {frame}.')
        elif len(filepaths) > 1:
            raise IndexError(f'Multiple frames found for {frame}.')
        return filepaths[0]

    def get_arrays(self, frame):
        # type: (int) -> list[np.ndarray]
        '''
        Get data and convert into numpy arrays according to labels.

        Args:
            frame (int): Frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            list[np.ndarray]: List of arrays from the given frame.
        '''
        labels = self.labels  # type: Any
        if labels is None or labels == []:
            return [self._read_file_as_array(self.get_filepath(frame))]

        item = self.__getitem(frame)

        # get labels
        if not isinstance(labels, list):
            labels = [labels]

        # if item is a numpy array, split it with np.split
        if isinstance(item, np.ndarray):
            arrays = list(np.split(item, labels, axis=self.label_axis))

        # otherwise item is an Image with channels
        else:
            chan = list(filter(lambda x: x not in labels, item.channels))
            img = item.to_bit_depth(cvd.BitDepth.FLOAT16)
            arrays = [img[:, :, chan].data, img[:, :, labels].data]

        # enforce shape equivalence
        max_dim = max(x.ndim for x in arrays)
        output = []
        for array in arrays:
            if array.ndim != max_dim:
                ndim = list(range(max_dim - array.ndim + 1, max_dim))
                array = np.expand_dims(array, axis=ndim)
            output.append(array)
        return output
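
    # Splitting sketch, assuming labels=[3] and label_axis=-1 on an (H, W, 4)
    # array: np.split yields an (H, W, 3) x array and an (H, W, 1) y array.
    #
    #     x, y = dataset.get_arrays(1)  # hypothetical frame 1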

    def _read_file(self, filepath):
        # type: (str) -> Any
        '''
        Read given file.

        Args:
            filepath (str): Filepath.

        Raises:
            IOError: If extension is not supported.

        Returns:
            object: File content.
        '''
        ext = Path(filepath).suffix[1:].lower()
        if ext == 'npy':
            return np.load(filepath)

        formats = [x.lower() for x in cvd.ImageFormat.__members__.keys()]
        formats += ['jpg']
        if ext in formats:
            return cvd.Image.read(filepath)

        raise IOError(f'Unsupported extension: {ext}')

    def _read_file_as_array(self, filepath):
        # type: (str) -> np.ndarray
        '''
        Read file as numpy array.

        Args:
            filepath (str): Filepath.

        Returns:
            np.ndarray: Array.
        '''
        item = self._read_file(filepath)

        ext = Path(filepath).suffix[1:].lower()
        if ext == 'npy':
            return item
        return item.data

437 @staticmethod 

438 def _resolve_limit(limit): 

439 # type: (Union[int, str, None]) -> tuple[int, str] 

440 ''' 

441 Resolves a given limit into a number of samples and limit type. 

442 

443 Args: 

444 limit (str, int, None): Limit descriptor. 

445 

446 Returns: 

447 tuple[int, str]: Number of samples and limit type. 

448 ''' 

449 if isinstance(limit, int): 

450 return limit, 'samples' 

451 

452 elif isinstance(limit, str): 

453 return hf.parse_size(limit), 'memory' 

454 

455 return -1, 'None' 
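
    # Resolution sketch (illustrative values; humanfriendly parses decimal
    # units by default, so '1GB' resolves to 10**9 bytes):
    #
    #     Dataset._resolve_limit(100)    # -> (100, 'samples')
    #     Dataset._resolve_limit('1GB')  # -> (1000000000, 'memory')
    #     Dataset._resolve_limit(None)   # -> (-1, 'None')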

    def load(self, limit=None, shuffle=False, reshape=True):
        # type: (Optional[Union[str, int]], bool, bool) -> Dataset
        '''
        Load data from files.

        Args:
            limit (str or int, optional): Limit data by number of samples or
                memory size. Default: None.
            shuffle (bool, optional): Shuffle frames before loading.
                Default: False.
            reshape (bool, optional): Reshape concatenated data to incorporate
                frames as the first dimension: (FRAME, ...). Analogous to the
                first dimension being batch. Default: True.

        Returns:
            Dataset: self.
        '''
        self.unload()

        # resolve limit
        limit_, limit_type = self._resolve_limit(limit)

        # shuffle rows
        rows = list(self.info.iterrows())
        if shuffle:
            random.shuffle(rows)

        # frame vars
        frames = []
        memory = 0
        samples = 0

        # tqdm message
        desc = 'Loading Dataset Files'
        if limit_type != 'None':
            desc = f'May not total to 100% - {desc}'

        # load frames
        for i, row in tqdm(rows, desc=desc):
            if limit_type == 'samples' and samples >= limit_:
                break
            elif limit_type == 'memory' and memory >= limit_:
                break

            frame = self._read_file_as_array(row.filepath)
            if reshape:
                frame = frame[np.newaxis, ...]
            frames.append(frame)

            self._info.loc[i, 'loaded'] = True
            memory += frame.nbytes
            samples += frame.shape[0]

        # concatenate data
        data = np.concatenate(frames, axis=0)

        # limit array size by samples
        if limit_type == 'samples':
            data = data[:limit_]

        # limit array size by memory
        elif limit_type == 'memory':
            sample_mem = data[0].nbytes
            delta = data.nbytes - limit_
            if delta > 0:
                k = int(delta / sample_mem)
                n = data.shape[0]
                data = data[:n - k]

        # set class members
        self.data = data
        self._sample_gb = data[0].nbytes / 10**9
        return self
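
    # Loading sketch (hypothetical limits):
    #
    #     dataset.load()                            # load everything
    #     dataset.load(limit=100)                   # at most 100 samples
    #     dataset.load(limit='2GB', shuffle=True)   # at most ~2 GB of samples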

    def unload(self):
        # type: () -> Dataset
        '''
        Delete self.data and reset self.info.

        Returns:
            Dataset: self.
        '''
        del self.data
        self.data = None
        self._info['loaded'] = False
        return self

    def xy_split(self):
        # type: () -> tuple[np.ndarray, np.ndarray]
        '''
        Split data into x and y arrays, using self.labels as the split index
        and self.label_axis as the split axis.

        Raises:
            EnforceError: If data has not been loaded.
            EnforceError: If self.labels is not a list of a single integer.

        Returns:
            tuple[np.ndarray]: x and y arrays.
        '''
        msg = 'Data not loaded. Please call load method.'
        Enforce(self.data, 'instance of', np.ndarray, message=msg)

        msg = 'self.labels must be a list of a single integer. '
        msg += f'Provided labels: {self.labels}.'

        labels = self.labels  # type: Any
        Enforce(labels, 'instance of', list, message=msg)
        Enforce(len(labels), '==', 1, message=msg)
        Enforce(labels[0], 'instance of', int, message=msg)
        # ----------------------------------------------------------------------

        return np.split(self.data, self.labels, axis=self.label_axis)  # type: ignore
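
    # Split sketch, assuming loaded data of shape (N, H, W, 4) with
    # labels=[3] and label_axis=-1:
    #
    #     x, y = dataset.xy_split()  # x: (N, H, W, 3), y: (N, H, W, 1)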

    def train_test_split(
        self,
        test_size=0.2,  # type: float
        limit=None,  # type: OptInt
        shuffle=True,  # type: bool
        seed=None,  # type: OptInt
    ):
        # type: (...) -> tuple[Dataset, Dataset]
        '''
        Split into train and test Datasets.

        Args:
            test_size (float, optional): Test set size as a proportion.
                Default: 0.2.
            limit (int, optional): Limit the total length of train and test.
                Default: None.
            shuffle (bool, optional): Randomize data before splitting.
                Default: True.
            seed (int, optional): Seed number. Default: None.

        Returns:
            tuple[Dataset]: Train Dataset, Test Dataset.
        '''
        train, test = fict.train_test_split(
            self.info,
            test_size=test_size, limit=limit, shuffle=shuffle, seed=seed
        )
        train.reset_index(drop=True, inplace=True)
        test.reset_index(drop=True, inplace=True)
        kwargs = dict(
            ext_regex=self._ext_regex,
            calc_file_size=self._calc_file_size,
            labels=self.labels,
            label_axis=self.label_axis
        )
        return Dataset(train, **kwargs), Dataset(test, **kwargs)  # type: ignore
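
    # End-to-end sketch (hypothetical directory; assumes the dataset was
    # constructed with labels=[3] so xy_split is valid):
    #
    #     train, test = Dataset \
    #         .read_directory('/data/my_asset') \
    #         .train_test_split(test_size=0.2, seed=42)
    #     x_train, y_train = train.load(shuffle=True).xy_split()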