Coverage for /home/ubuntu/flatiron/python/flatiron/core/multidataset.py: 100%
53 statements
from typing import Any, Optional, Union  # noqa: F401
from flatiron.core.types import OptInt  # noqa: F401
import numpy as np  # noqa: F401

from pathlib import Path

import pandas as pd

from flatiron.core.dataset import Dataset
import flatiron.core.tools as fict
# ------------------------------------------------------------------------------

class MultiDataset:
    '''
    This class combines a dictionary of Dataset instances into a single dataset.
    Datasets are merged by frame.
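
    Example (illustrative; 'lumi' and 'segm' are hypothetical Dataset
    instances):
        >>> mset = MultiDataset(dict(lumi=lumi_dset, segm=segm_dset))  # doctest: +SKIP
        >>> mset.info.columns.tolist()  # doctest: +SKIP
        ['frame', 'filepath_lumi', 'filepath_segm']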
    '''
    def __init__(self, datasets):
        # type: (dict[str, Dataset]) -> None
        '''
        Constructs a MultiDataset instance.

        Args:
            datasets (dict[str, Dataset]): Dictionary of Dataset instances.
        '''
        self.datasets = datasets

        data = None  # type: Any
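        # merge each dataset's (frame, filepath) table into a single
        # DataFrame, suffixing overlapping filepath columns by dataset key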
        for key, item in sorted(datasets.items()):
            info = item.info
            info['filepath'] = info.apply(
                lambda x: Path(x.asset_path, x.filepath_relative).as_posix(),
                axis=1
            )
            info = info[['frame', 'filepath']]
            if data is None:
                data = info
                prev = key
            else:
                suffix = [f'_{prev}', f'_{key}']
                data = pd.merge(data, info, on='frame', suffixes=suffix)

        self._info = data  # type: pd.DataFrame

    @property
    def info(self):
        # type: () -> pd.DataFrame
        '''
        Returns:
            DataFrame: Copy of info DataFrame.
        '''
        return self._info.copy()

    def __len__(self):
        # type: () -> int
        '''
        Returns:
            int: Number of frames.
        '''
        return len(self._info)

    def __getitem__(self, frame):
        # type: (int) -> dict[str, Any]
        '''
        For each dataset, fetch data by given frame.

        Returns:
            dict: Dict where values are data of the given frame.
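
        Example (hypothetical keys, mirroring the class example):
            >>> mset[9]  # doctest: +SKIP
            {'lumi': ..., 'segm': ...}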
        '''
        return {k: v[frame] for k, v in self.datasets.items()}

    def get_filepaths(self, frame):
        # type: (int) -> dict[str, str]
        '''
        For each dataset, get filepath of given frame.

        Returns:
            dict: Dict where values are filepaths of the given frame.
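
        Example (hypothetical paths, for illustration only):
            >>> mset.get_filepaths(0)  # doctest: +SKIP
            {'lumi': '/mnt/data/lumi/frame-0000.npy',
             'segm': '/mnt/data/segm/frame-0000.npy'}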
        '''
        return {k: v.get_filepath(frame) for k, v in self.datasets.items()}

    def get_arrays(self, frame):
        # type: (int) -> dict[str, list[np.ndarray]]
        '''
        For each dataset, get data and convert into numpy arrays according to
        labels.

        Args:
            frame (int): Frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            dict: Dict where values are lists of arrays from the given frame.
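
        Example (illustrative shapes; actual shapes depend on the data):
            >>> arrays = mset.get_arrays(0)  # doctest: +SKIP
            >>> [x.shape for x in arrays['lumi']]  # doctest: +SKIP
            [(256, 256, 3), (256, 256, 1)]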
        '''
        return {k: v.get_arrays(frame) for k, v in self.datasets.items()}

    def load(self, limit=None, reshape=True):
        # type: (Optional[Union[str, int]], bool) -> MultiDataset
        '''
        For each dataset, load data from files.

        Args:
            limit (str or int, optional): Limit data by number of samples or
                memory size. Default: None.
            reshape (bool, optional): Reshape concatenated data to incorporate
                frames as the first dimension: (FRAME, ...). Analogous to the
                first dimension being batch. Default: True.

        Returns:
            MultiDataset: self.
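
        Example (illustrative; limits each dataset to 100 samples):
            >>> mset = mset.load(limit=100)  # doctest: +SKIP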
        '''
        kwargs = dict(limit=limit, shuffle=False, reshape=reshape)  # type: Any
        for dset in self.datasets.values():
            dset.load(**kwargs)
        return self

    def unload(self):
        # type: () -> MultiDataset
        '''
        For each dataset, delete self.data and reset self.info.

        Returns:
            MultiDataset: self.
        '''
        for dset in self.datasets.values():
            dset.unload()
        return self

    def xy_split(self):
        # type: () -> dict[str, tuple[np.ndarray, np.ndarray]]
        '''
        For each dataset, split data into x and y arrays, according to
        self.labels as the split index and self.label_axis as the split axis.

        Raises:
            EnforceError: If data has not been loaded.
            EnforceError: If self.labels is not a list of a single integer.

        Returns:
            dict: Dict where values are x and y arrays.
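
        Example (hypothetical keys):
            >>> xy = mset.load().xy_split()  # doctest: +SKIP
            >>> x, y = xy['lumi']  # doctest: +SKIP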
        '''
        return {k: v.xy_split() for k, v in self.datasets.items()}

    def train_test_split(
        self,
        test_size=0.2,  # type: float
        limit=None,  # type: OptInt
        shuffle=True,  # type: bool
        seed=None,  # type: OptInt
    ):
        # type: (...) -> tuple[MultiDataset, MultiDataset]
        '''
        Split into train and test MultiDatasets.

        Args:
            test_size (float, optional): Test set size as a proportion.
                Default: 0.2.
            limit (int, optional): Limit the total length of train and test.
                Default: None.
            shuffle (bool, optional): Randomize data before splitting.
                Default: True.
            seed (int, optional): Random seed. Default: None.

        Returns:
            tuple[MultiDataset, MultiDataset]: Train MultiDataset, Test
                MultiDataset.
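
        Example (illustrative):
            >>> train, test = mset.train_test_split(test_size=0.25, seed=42)  # doctest: +SKIP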
        '''
        items = fict.train_test_split(
            self.info,
            test_size=test_size, limit=limit, shuffle=shuffle, seed=seed
        )
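
        # rebuild each underlying Dataset, keeping only rows whose frame
        # appears in the corresponding train or test split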
        msets = dict(train={}, test={})  # type: Any
        for key, val in zip(['train', 'test'], items):
            frames = val.frame.tolist()
            for name, dset in self.datasets.items():
                info = dset._info
                mask = info.frame.apply(lambda x: x in frames)
                info = info[mask].copy()
                info.reset_index(drop=True, inplace=True)
                msets[key][name] = Dataset(
                    info=info,
                    ext_regex=dset._ext_regex,
                    calc_file_size=dset._calc_file_size,
                    labels=dset.labels,
                    label_axis=dset.label_axis,
                )

        return MultiDataset(msets['train']), MultiDataset(msets['test'])