Coverage for /home/ubuntu/flatiron/python/flatiron/core/dataset.py: 100% (243 statements)
from typing import Any, Optional, Union  # noqa F401
from flatiron.core.types import Filepath, OptArray, OptInt, OptLabels  # noqa F401

from pathlib import Path
import os
import random
import re

from lunchbox.enforce import Enforce
from tqdm.notebook import tqdm
import cv_depot.api as cvd
import humanfriendly as hf
import numpy as np
import pandas as pd

import flatiron.core.tools as fict
# ------------------------------------------------------------------------------


class Dataset:
    @classmethod
    def read_csv(cls, filepath, **kwargs):
        # type: (Filepath, Any) -> Dataset
        '''
        Construct Dataset instance from given csv filepath.

        Args:
            filepath (str or Path): Info CSV filepath.

        Raises:
            EnforceError: If filepath does not exist or is not a CSV.

        Returns:
            Dataset: Dataset instance.
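
        Example (illustrative; the path below is hypothetical):

            >>> dataset = Dataset.read_csv('/tmp/proj001/info.csv')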
35 '''
36 fp = Path(filepath)
37 msg = f'Filepath does not exist: {fp}'
38 Enforce(fp.is_file(), '==', True, message=msg)
40 msg = f'Filepath extension must be csv. Given filepath: {fp}'
41 Enforce(fp.suffix.lower()[1:], '==', 'csv', message=msg)
42 # ----------------------------------------------------------------------
44 info = pd.read_csv(filepath)
45 return cls(info, **kwargs)

    @classmethod
    def read_directory(cls, directory, **kwargs):
        # type: (Filepath, Any) -> Dataset
        '''
        Construct dataset from directory.

        Args:
            directory (str or Path): Dataset directory.

        Raises:
            EnforceError: If directory does not exist.
            EnforceError: If more or less than 1 CSV file found in directory.

        Returns:
            Dataset: Dataset instance.
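
        Example (illustrative; the directory is hypothetical and must contain
        exactly one info CSV):

            >>> dataset = Dataset.read_directory('/tmp/proj001')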
        '''
        msg = f'Directory does not exist: {directory}'
        Enforce(Path(directory).is_dir(), '==', True, message=msg)

        files = [Path(directory, x) for x in os.listdir(directory)]  # type: Any
        files = list(filter(lambda x: x.suffix.lower()[1:] == 'csv', files))
        files = sorted([x.as_posix() for x in files])
        msg = 'Dataset directory must contain only 1 CSV file. '
        msg += f'CSV files found: {files}'
        Enforce(len(files), '==', 1, message=msg)
        # ----------------------------------------------------------------------

        return cls.read_csv(files[0], **kwargs)

    def __init__(
        self, info, ext_regex='npy|exr|png|jpeg|jpg|tiff', calc_file_size=True,
        labels=None, label_axis=-1
    ):
        # type: (pd.DataFrame, str, bool, OptLabels, int) -> None
        '''
        Construct a Dataset instance.
        If labels is an integer, it is assumed to be an axis upon which the
        data will be split.

        Args:
            info (pd.DataFrame): Info DataFrame.
            ext_regex (str, optional): File extension pattern.
                Default: 'npy|exr|png|jpeg|jpg|tiff'.
            calc_file_size (bool, optional): Calculate file size in GB.
                Default: True.
            labels (object, optional): Label channels. Default: None.
            label_axis (int, optional): Label axis. Default: -1.

        Raises:
            EnforceError: If info is not an instance of DataFrame.
            EnforceError: If required columns not found in info.
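
        Example (a minimal sketch; paths are hypothetical, the files must
        exist on disk, and filenames must carry frame indicators such as
        '_f0001'):

            >>> info = pd.DataFrame([
            ...     dict(asset_path='/tmp/proj001', filepath_relative='data_f0001.npy'),
            ...     dict(asset_path='/tmp/proj001', filepath_relative='data_f0002.npy'),
            ... ])
            >>> dataset = Dataset(info, labels=[3])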
        '''
        Enforce(info, 'instance of', pd.DataFrame)

        # columns
        columns = info.columns.tolist()
        required = ['asset_path', 'filepath_relative']
        diff = sorted(list(set(required).difference(columns)))
        msg = f'Required columns not found in info: {diff}'
        Enforce(len(diff), '==', 0, message=msg)

        # root path
        root = info.asset_path.unique().tolist()
        msg = f'Info must contain only 1 root path. Paths found: {root}'
        Enforce(len(root), '==', 1, message=msg)
        root = root[0]
        msg = f'Directory does not exist: {root}'
        Enforce(Path(root).is_dir(), '==', True, message=msg)

        # info files
        info['filepath'] = info.filepath_relative \
            .apply(lambda x: Path(root, x).as_posix())
        mask = info.filepath.apply(lambda x: not Path(x).is_file())
        absent = info.loc[mask, 'filepath'].tolist()
        msg = f'Files do not exist: {absent}'
        Enforce(len(absent), '==', 0, message=msg)

        # extension
        mask = info.filepath \
            .apply(lambda x: Path(x).suffix.lower()[1:]) \
            .apply(lambda x: re.search(ext_regex, x, re.I) is None)
        bad_ext = sorted(info.loc[mask, 'filepath'].tolist())
        msg = f'Found file extensions that do not match ext_regex: {bad_ext}'
        Enforce(len(bad_ext), '==', 0, message=msg)

        # frame indicators
        frame_regex = r'_(f|c)(\d+)\.' + f'({ext_regex})$'
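        # e.g. 'data_f0001.npy' and 'data_c0010.png' match frame_regex;
        # group(2) captures the frame number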
        mask = info.filepath.apply(lambda x: re.search(frame_regex, x) is None)
        bad_frames = info.loc[mask, 'filepath'].tolist()
        msg = 'Found files missing frame indicators. '
        msg += f"File names must match '{frame_regex}'. "
        msg += f'Invalid frames: {bad_frames}'
        Enforce(len(bad_frames), '==', 0, message=msg)

        # frame column
        info['frame'] = info.filepath \
            .apply(lambda x: re.search(frame_regex, x).group(2)).astype(int)  # type: ignore

        # gb column
        info['gb'] = np.nan
        if calc_file_size:
            info['gb'] = info.filepath \
                .apply(lambda x: os.stat(x).st_size / 10**9)

        # loaded column
        info['loaded'] = False
        # ----------------------------------------------------------------------

        # reorganize columns
        cols = [
            'gb', 'frame', 'asset_path', 'filepath_relative', 'filepath',
            'loaded'
        ]
        cols = cols + info.drop(cols, axis=1).columns.tolist()
        info = info[cols]

        self._info = info  # type: pd.DataFrame
        self.data = None  # type: OptArray
        self.labels = labels
        self.label_axis = label_axis
        self._ext_regex = ext_regex
        self._calc_file_size = calc_file_size
        self._sample_gb = np.nan  # type: Union[float, np.ndarray]

    @property
    def info(self):
        # type: () -> pd.DataFrame
        '''
        Returns:
            DataFrame: Copy of info DataFrame.
        '''
        return self._info.copy()

    @property
    def filepaths(self):
        # type: () -> list[str]
        '''
        Returns:
            list[str]: Filepaths sorted by frame.
        '''
        return self._info.sort_values('frame').filepath.tolist()

    @property
    def asset_path(self):
        # type: () -> str
        '''
        Returns:
            str: Asset path of Dataset.
        '''
        return self.info.loc[0, 'asset_path']

    @property
    def asset_name(self):
        # type: () -> str
        '''
        Returns:
            str: Asset name of Dataset.
        '''
        return Path(self.asset_path).name

    @property
    def stats(self):
        # type: () -> pd.DataFrame
        '''
        Generates a table of statistics of info data.

        Metrics include:

        * min
        * max
        * mean
        * std
        * loaded
        * total

        Units include:

        * gb
        * frame
        * sample

        Returns:
            DataFrame: Table of statistics.
        '''
        info = self.info
        a = self._get_stats(info)
        b = self._get_stats(info.loc[info.loaded]) \
            .loc[['total']].rename(lambda x: 'loaded')
        stats = pd.concat([a, b])
        stats['sample'] = np.nan

        if self.data is not None:
            loaded = round(self.data.nbytes / 10**9, 2)
            stats.loc['loaded', 'gb'] = loaded

            # sample stats
            total = info['gb'].sum() / self._sample_gb
            stats.loc['loaded', 'sample'] = self.data.shape[0]
            stats.loc['total', 'sample'] = total
            stats.loc['mean', 'sample'] = total / info.shape[0]
            stats['sample'] = stats['sample'].apply(lambda x: round(x, 0))

        index = ['min', 'max', 'mean', 'std', 'loaded', 'total']
        stats = stats.loc[index]
        return stats

    @staticmethod
    def _get_stats(info):
        # type: (pd.DataFrame) -> pd.DataFrame
        '''
        Creates table of statistics from given info DataFrame.

        Args:
            info (pd.DataFrame): Info DataFrame.

        Returns:
            pd.DataFrame: Stats DataFrame.
        '''
        stats = info.describe()
        rows = ['min', 'max', 'mean', 'std', 'count']
        stats = stats.loc[rows]
        stats.loc['total'] = info[stats.columns].sum()
        stats.loc['total', 'frame'] = stats.loc['count', 'frame']
        stats.loc['mean', 'frame'] = np.nan
        stats.loc['std', 'frame'] = np.nan
        stats = stats.map(lambda x: round(x, 2))
        stats.drop('count', inplace=True)
        return stats

    def __repr__(self):
        # type: () -> str
        '''
        Returns:
            str: Info statistics.
        '''
        msg = f'''
        <Dataset>
          ASSET_NAME: {self.asset_name}
          ASSET_PATH: {self.asset_path}
          STATS:
        '''[1:]
        msg = fict.unindent(msg, spaces=8)
        cols = ['gb', 'frame', 'sample']
        stats = str(self.stats[cols])
        stats = '\n '.join(stats.split('\n'))
        msg = msg + stats
        return msg

    def __len__(self):
        # type: () -> int
        '''
        Returns:
            int: Number of frames.
        '''
        return len(self._info)

    def __getitem(self, frame):
        # type: (int) -> Any
        '''
        Get data by frame.
        This is needed to avoid recursion errors when overloading __getitem__.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            object: Data of given frame.
        '''
        return self._read_file(self.get_filepath(frame))

    def __getitem__(self, frame):
        # type: (int) -> Any
        '''
        Get data by frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            object: Data of given frame.
        '''
        return self.__getitem(frame)

    def get_filepath(self, frame):
        # type: (int) -> Any
        '''
        Get filepath of given frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            str: Filepath of given frame.
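
        Example (illustrative; path and filename are hypothetical):

            >>> dataset.get_filepath(1)
            '/tmp/proj001/data_f0001.npy'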
        '''
        info = self._info
        mask = info.frame == frame
        filepaths = info.loc[mask, 'filepath'].tolist()
        if len(filepaths) == 0:
            raise IndexError(f'Missing frame {frame}.')
        elif len(filepaths) > 1:
            raise IndexError(f'Multiple frames found for {frame}.')
        return filepaths[0]

    def get_arrays(self, frame):
        # type: (int) -> list[np.ndarray]
        '''
        Get data and convert into numpy arrays according to labels.

        Args:
            frame (int): Frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            list[np.ndarray]: List of arrays from the given frame.
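
        Example (a sketch; assumes frame 1 is a NPY file of shape (64, 64, 4)
        and self.labels is [3]):

            >>> x, y = dataset.get_arrays(1)
            >>> x.shape, y.shape
            ((64, 64, 3), (64, 64, 1))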
        '''
        labels = self.labels  # type: Any
        if labels is None or labels == []:
            return [self._read_file_as_array(self.get_filepath(frame))]

        item = self.__getitem(frame)

        # get labels
        if not isinstance(labels, list):
            labels = [labels]

        # if item is a numpy array return a np.split
        if isinstance(item, np.ndarray):
            arrays = list(np.split(item, labels, axis=self.label_axis))

        # otherwise item is an Image with channels
        else:
            chan = list(filter(lambda x: x not in labels, item.channels))
            img = item.to_bit_depth(cvd.BitDepth.FLOAT16)
            arrays = [img[:, :, chan].data, img[:, :, labels].data]

        # enforce shape equivalence
        max_dim = max(x.ndim for x in arrays)
        output = []
        for array in arrays:
            if array.ndim != max_dim:
                ndim = list(range(max_dim - array.ndim + 1, max_dim))
                array = np.expand_dims(array, axis=ndim)
            output.append(array)
        return output

    def _read_file(self, filepath):
        # type: (str) -> Any
        '''
        Read given file.

        Args:
            filepath (str): Filepath.

        Raises:
            IOError: If extension is not supported.

        Returns:
            object: File content.
        '''
        ext = Path(filepath).suffix[1:].lower()
        if ext == 'npy':
            return np.load(filepath)

        formats = [x.lower() for x in cvd.ImageFormat.__members__.keys()]
        formats += ['jpg']
        if ext in formats:
            return cvd.Image.read(filepath)

        raise IOError(f'Unsupported extension: {ext}')

    def _read_file_as_array(self, filepath):
        # type: (str) -> np.ndarray
        '''
        Read file as numpy array.

        Args:
            filepath (str): Filepath.

        Returns:
            np.ndarray: Array.
        '''
        item = self._read_file(filepath)

        ext = Path(filepath).suffix[1:].lower()
        if ext == 'npy':
            return item
        return item.data

    @staticmethod
    def _resolve_limit(limit):
        # type: (Union[int, str, None]) -> tuple[int, str]
        '''
        Resolves a given limit into a number of samples and limit type.

        Args:
            limit (str, int, None): Limit descriptor.

        Returns:
            tuple[int, str]: Number of samples and limit type.
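
        Example (string sizes follow humanfriendly's decimal parsing):

            >>> Dataset._resolve_limit(100)
            (100, 'samples')
            >>> Dataset._resolve_limit('1GB')
            (1000000000, 'memory')
            >>> Dataset._resolve_limit(None)
            (-1, 'None')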
        '''
        if isinstance(limit, int):
            return limit, 'samples'

        elif isinstance(limit, str):
            return hf.parse_size(limit), 'memory'

        return -1, 'None'

    def load(self, limit=None, shuffle=False, reshape=True):
        # type: (Optional[Union[str, int]], bool, bool) -> Dataset
        '''
        Load data from files.

        Args:
            limit (str or int, optional): Limit data by number of samples or
                memory size. Default: None.
            shuffle (bool, optional): Shuffle frames before loading.
                Default: False.
            reshape (bool, optional): Reshape concatenated data to incorporate
                frames as the first dimension: (FRAME, ...). Analogous to the
                first dimension being batch. Default: True.

        Returns:
            Dataset: self.
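
        Example (illustrative; loads at most roughly 1 GB of frames in random
        order):

            >>> dataset.load(limit='1GB', shuffle=True)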
        '''
        self.unload()

        # resolve limit
        limit_, limit_type = self._resolve_limit(limit)

        # shuffle rows
        rows = list(self.info.iterrows())
        if shuffle:
            random.shuffle(rows)

        # frame vars
        frames = []
        memory = 0
        samples = 0

        # tqdm message
        desc = 'Loading Dataset Files'
        if limit_type != 'None':
            desc = f'May not total to 100% - {desc}'

        # load frames
        for i, row in tqdm(rows, desc=desc):
            if limit_type == 'samples' and samples >= limit_:
                break
            elif limit_type == 'memory' and memory >= limit_:
                break

            frame = self._read_file_as_array(row.filepath)
            if reshape:
                frame = frame[np.newaxis, ...]
            frames.append(frame)

            self._info.loc[i, 'loaded'] = True
            memory += frame.nbytes
            samples += frame.shape[0]

        # concatenate data
        data = np.concatenate(frames, axis=0)

        # limit array size by samples
        if limit_type == 'samples':
            data = data[:limit_]

        # limit array size by memory
        elif limit_type == 'memory':
            sample_mem = data[0].nbytes
            delta = data.nbytes - limit_
            if delta > 0:
                k = int(delta / sample_mem)
                n = data.shape[0]
                data = data[:n - k]

        # set class members
        self.data = data
        self._sample_gb = data[0].nbytes / 10**9
        return self

    def unload(self):
        # type: () -> Dataset
        '''
        Delete self.data and reset self.info.

        Returns:
            Dataset: self.
        '''
        del self.data
        self.data = None
        self._info['loaded'] = False
        return self

    def xy_split(self):
        # type: () -> tuple[np.ndarray, np.ndarray]
        '''
        Split data into x and y arrays, according to self.labels as the split
        index and self.label_axis as the split axis.

        Raises:
            EnforceError: If data has not been loaded.
            EnforceError: If self.labels is not a list containing a single
                integer.

        Returns:
            tuple[np.ndarray]: x and y arrays.
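
        Example (a sketch; assumes loaded data of shape (10, 64, 64, 4) with
        labels=[3] and label_axis=-1):

            >>> x, y = dataset.xy_split()
            >>> x.shape, y.shape
            ((10, 64, 64, 3), (10, 64, 64, 1))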
        '''
        msg = 'Data not loaded. Please call load method.'
        Enforce(self.data, 'instance of', np.ndarray, message=msg)

        msg = 'self.labels must be a list containing a single integer. '
        msg += f'Provided labels: {self.labels}.'

        labels = self.labels  # type: Any
        Enforce(labels, 'instance of', list, message=msg)
        Enforce(len(labels), '==', 1, message=msg)
        Enforce(labels[0], 'instance of', int, message=msg)
        # ----------------------------------------------------------------------

        return np.split(self.data, self.labels, axis=self.label_axis)  # type: ignore

    def train_test_split(
        self,
        test_size=0.2,  # type: float
        limit=None,  # type: OptInt
        shuffle=True,  # type: bool
        seed=None,  # type: OptInt
    ):
        # type: (...) -> tuple[Dataset, Dataset]
        '''
        Split into train and test Datasets.

        Args:
            test_size (float, optional): Test set size as a proportion.
                Default: 0.2.
            limit (int, optional): Limit the total length of train and test.
                Default: None.
            shuffle (bool, optional): Randomize data before splitting.
                Default: True.
            seed (int, optional): Random seed. Default: None.

        Returns:
            tuple[Dataset]: Train Dataset, Test Dataset.
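
        Example (illustrative; an 80/20 split of the info table):

            >>> train, test = dataset.train_test_split(test_size=0.2, seed=42)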
        '''
        train, test = fict.train_test_split(
            self.info,
            test_size=test_size, limit=limit, shuffle=shuffle, seed=seed
        )
        train.reset_index(drop=True, inplace=True)
        test.reset_index(drop=True, inplace=True)
        kwargs = dict(
            ext_regex=self._ext_regex,
            calc_file_size=self._calc_file_size,
            labels=self.labels,
            label_axis=self.label_axis
        )
        return Dataset(train, **kwargs), Dataset(test, **kwargs)  # type: ignore
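

# ------------------------------------------------------------------------------
# A minimal end-to-end usage sketch (paths hypothetical):
#
#     dataset = Dataset.read_directory('/tmp/proj001', labels=[3])
#     train, test = dataset.train_test_split(test_size=0.2, seed=42)
#     train.load(limit='1GB', shuffle=True)
#     x_train, y_train = train.xy_split()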