Coverage for /home/ubuntu/flatiron/python/flatiron/core/multidataset.py: 100%

53 statements  

coverage.py v7.8.0, created at 2025-04-08 21:55 +0000

from typing import Any, Optional, Union  # noqa F401
from flatiron.core.types import OptInt  # noqa F401
import numpy as np  # noqa F401

from pathlib import Path

import pandas as pd

from flatiron.core.dataset import Dataset
import flatiron.core.tools as fict
# ------------------------------------------------------------------------------


class MultiDataset:
    '''
    This class combines a dictionary of Dataset instances into a single
    dataset. Datasets are merged by frame.
    '''
    def __init__(self, datasets):
        # type: (dict[str, Dataset]) -> None
        '''
        Constructs a MultiDataset instance.

        Args:
            datasets (dict[str, Dataset]): Dictionary of Dataset instances.
        '''
        self.datasets = datasets

        data = None  # type: Any
        for key, item in sorted(datasets.items()):
            info = item.info
            info['filepath'] = info.apply(
                lambda x: Path(x.asset_path, x.filepath_relative).as_posix(),
                axis=1
            )
            info = info[['frame', 'filepath']]
            if data is None:
                data = info
                prev = key
            else:
                suffix = [f'_{prev}', f'_{key}']
                data = pd.merge(data, info, on='frame', suffixes=suffix)

        self._info = data  # type: pd.DataFrame
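
        # For example (a hedged illustration; the keys 'rgb' and 'depth' are
        # hypothetical), MultiDataset(dict(rgb=a, depth=b)) produces an info
        # DataFrame with columns ['frame', 'filepath_depth', 'filepath_rgb'].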

    @property
    def info(self):
        # type: () -> pd.DataFrame
        '''
        Returns:
            DataFrame: Copy of info DataFrame.
        '''
        return self._info.copy()

    def __len__(self):
        # type: () -> int
        '''
        Returns:
            int: Number of frames.
        '''
        return len(self._info)

    def __getitem__(self, frame):
        # type: (int) -> dict[str, Any]
        '''
        For each dataset, fetch data by given frame.

        Args:
            frame (int): Frame.

        Returns:
            dict: Dict where values are data of the given frame.
        '''
        return {k: v[frame] for k, v in self.datasets.items()}
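
        # Example (hedged sketch; 'rgb' and 'depth' are hypothetical keys):
        #     mset = MultiDataset(dict(rgb=rgb_dset, depth=depth_dset))
        #     mset[3]  # -> dict(rgb=<frame 3 data>, depth=<frame 3 data>)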

    def get_filepaths(self, frame):
        # type: (int) -> dict[str, str]
        '''
        For each dataset, get filepath of given frame.

        Args:
            frame (int): Frame.

        Returns:
            dict: Dict where values are filepaths of the given frame.
        '''
        return {k: v.get_filepath(frame) for k, v in self.datasets.items()}

    def get_arrays(self, frame):
        # type: (int) -> dict[str, list[np.ndarray]]
        '''
        For each dataset, get data of the given frame and convert it into
        numpy arrays according to labels.

        Args:
            frame (int): Frame.

        Raises:
            IndexError: If frame is missing or multiple frames were found.

        Returns:
            dict: Dict where values are lists of arrays from the given frame.
        '''
        return {k: v.get_arrays(frame) for k, v in self.datasets.items()}
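
        # For example (hedged; keys are hypothetical):
        #     mset.get_arrays(7)  # -> dict(rgb=[<ndarray>, ...], depth=[<ndarray>, ...])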

    def load(self, limit=None, reshape=True):
        # type: (Optional[Union[str, int]], bool) -> MultiDataset
        '''
        For each dataset, load data from files.

        Args:
            limit (str or int, optional): Limit data by number of samples or
                memory size. Default: None.
            reshape (bool, optional): Reshape concatenated data to incorporate
                frames as the first dimension: (FRAME, ...). Analogous to the
                first dimension being batch. Default: True.

        Returns:
            MultiDataset: self.
        '''
        kwargs = dict(limit=limit, shuffle=False, reshape=reshape)  # type: Any
        for dataset in self.datasets.values():
            dataset.load(**kwargs)
        return self
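
        # Example (hedged; per the docstring, `limit` may also be a memory
        # size string, whose exact format depends on Dataset.load):
        #     mset.load(limit=100)  # load at most 100 samples per dataset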

    def unload(self):
        # type: () -> MultiDataset
        '''
        For each dataset, delete self.data and reset self.info.

        Returns:
            MultiDataset: self.
        '''
        for dataset in self.datasets.values():
            dataset.unload()
        return self

    def xy_split(self):
        # type: () -> dict[str, tuple[np.ndarray, np.ndarray]]
        '''
        For each dataset, split data into x and y arrays, using self.labels as
        the split index and self.label_axis as the split axis.

        Raises:
            EnforceError: If data has not been loaded.
            EnforceError: If self.labels is not a list containing a single
                integer.

        Returns:
            dict: Dict where values are x and y arrays.
        '''
        return {k: v.xy_split() for k, v in self.datasets.items()}
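
        # Example (hedged; 'rgb' is a hypothetical key):
        #     splits = mset.load().xy_split()
        #     x, y = splits['rgb']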

    def train_test_split(
        self,
        test_size=0.2,  # type: float
        limit=None,  # type: OptInt
        shuffle=True,  # type: bool
        seed=None,  # type: OptInt
    ):
        # type: (...) -> tuple[MultiDataset, MultiDataset]
        '''
        Split into train and test MultiDatasets.

        Args:
            test_size (float, optional): Test set size as a proportion.
                Default: 0.2.
            limit (int, optional): Limit the total length of train and test.
                Default: None.
            shuffle (bool, optional): Randomize data before splitting.
                Default: True.
            seed (int, optional): Random seed. Default: None.

        Returns:
            tuple[MultiDataset, MultiDataset]: Train MultiDataset, test
                MultiDataset.
        '''
        items = fict.train_test_split(
            self.info,
            test_size=test_size, limit=limit, shuffle=shuffle, seed=seed
        )

        msets = dict(train={}, test={})  # type: Any
        for key, val in zip(['train', 'test'], items):
            frames = val.frame.tolist()
            for name, dset in self.datasets.items():
                info = dset._info
                mask = info.frame.apply(lambda x: x in frames)
                info = info[mask].copy()
                info.reset_index(drop=True, inplace=True)
                msets[key][name] = Dataset(
                    info=info,
                    ext_regex=dset._ext_regex,
                    calc_file_size=dset._calc_file_size,
                    labels=dset.labels,
                    label_axis=dset.label_axis,
                )

        return MultiDataset(msets['train']), MultiDataset(msets['test'])
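

# ------------------------------------------------------------------------------
# Example usage (a hedged sketch, not part of the original module; the dataset
# construction and key names below are hypothetical and depend on the Dataset
# API):
#
#     datasets = dict(rgb=rgb_dataset, depth=depth_dataset)
#     mset = MultiDataset(datasets)
#     train, test = mset.train_test_split(test_size=0.2, seed=42)
#     train.load(limit=1000)
#     splits = train.xy_split()  # dict of (x, y) ndarray tuples per key
#     x_rgb, y_rgb = splits['rgb']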