Coverage for /home/ubuntu/rolling-pin/python/rolling_pin/repo_etl.py: 100%

208 statements  

coverage.py v7.6.12, created at 2025-02-13 19:35 +0000

from typing import Any, Dict, Iterator, List, Optional, Union  # noqa: F401
from IPython.display import HTML, Image  # noqa: F401

from itertools import chain
from pathlib import Path
import os
import re

from pandas import DataFrame, Series
import lunchbox.tools as lbt
import networkx
import numpy as np
import pandas as pd

import rolling_pin.tools as rpt
# ------------------------------------------------------------------------------

'''
Contains the RepoETL class, which is used for converting python repository
module dependencies into a directed graph.
'''


class RepoETL():
    '''
    RepoETL is a class for extracting 1st order dependencies of modules within a
    given repository. This information is stored internally as a DataFrame and
    can be rendered as networkx, pydot or SVG graphs.
    '''
    def __init__(
        self,
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|test_|_test|mock_)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> None
        r'''
        Construct RepoETL instance.

        Args:
            root (str or Path): Full path to repository root directory.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|test_|_test|mock_)\.py$'.

        Raises:
            ValueError: If include or exclude regex does not end in '\.py$'.
        '''
        self._root = root  # type: Union[str, Path]
        self._data = self._get_data(root, include_regex, exclude_regex)  # type: DataFrame

    @staticmethod
    def _get_imports(fullpath):
        # type: (Union[str, Path]) -> List[str]
        '''
        Gets import statements from a given python module.

        Args:
            fullpath (str or Path): Path to python module.

        Returns:
            list(str): List of imported modules.
        '''
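        # Illustrative note (added comment, not in the original source): the
        # pipeline below reduces each import line to its root package, e.g.
        # 'from lunchbox.tools import x' -> 'lunchbox.tools' and
        # 'import numpy as np' -> 'numpy', then drops standard library modules.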

        with open(fullpath) as f:
            data = f.readlines()  # type: Union[List, Iterator]
        data = map(lambda x: x.strip('\n'), data)
        data = filter(lambda x: re.search('^import|^from', x), data)
        data = map(lambda x: re.sub('from (.*?) .*', '\\1', x), data)
        data = map(lambda x: re.sub(' as .*', '', x), data)
        data = map(lambda x: re.sub(' *#.*', '', x), data)
        data = map(lambda x: re.sub('import ', '', x), data)
        data = filter(lambda x: not lbt.is_standard_module(x), data)
        return list(data)

    @staticmethod
    def _get_data(
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|_test)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> DataFrame
        r'''
        Recursively aggregates and filters all the files found within a given
        directory into a DataFrame. Data is used to create directed graphs.

        DataFrame has these columns:

        * node_name - name of node
        * node_type - type of node, can be [module, subpackage, library]
        * x - node's x coordinate
        * y - node's y coordinate
        * dependencies - parent nodes
        * subpackages - parent nodes of type subpackage
        * fullpath - fullpath to the module a node represents

        Args:
            root (str or Path): Root directory to be searched.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|_test)\.py$'.

        Raises:
            ValueError: If include or exclude regex does not end in '\.py$'.
            FileNotFoundError: If no files are found after filtering.

        Returns:
            DataFrame: DataFrame of file information.
        '''
        root = Path(root).as_posix()
        files = rpt.list_all_files(root)  # type: Union[Iterator, List]
        if include_regex != '':
            if not include_regex.endswith(r'\.py$'):
                msg = f"Invalid include_regex: '{include_regex}'. "
                msg += r"Does not end in '.py$'."
                raise ValueError(msg)

            files = filter(
                lambda x: re.search(include_regex, x.absolute().as_posix()),
                files
            )
        if exclude_regex != '':
            files = filter(
                lambda x: not re.search(exclude_regex, x.absolute().as_posix()),
                files
            )

        files = list(files)
        if len(files) == 0:
            msg = f'No files found after filters in directory: {root}.'
            raise FileNotFoundError(msg)

        # build DataFrame of nodes and imported dependencies
        data = DataFrame()
        data['fullpath'] = files
        data.fullpath = data.fullpath.apply(lambda x: x.absolute().as_posix())

        data['node_name'] = data.fullpath\
            .apply(lambda x: re.sub(root, '', x))\
            .apply(lambda x: re.sub(r'\.py$', '', x))\
            .apply(lambda x: re.sub('^/', '', x))\
            .apply(lambda x: re.sub('/', '.', x))
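        # illustrative example (comment added for clarity, not in the original
        # source): '<root>/rolling_pin/repo_etl.py' -> 'rolling_pin.repo_etl'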

        data['subpackages'] = data.node_name\
            .apply(lambda x: rpt.get_parent_fields(x, '.')).apply(lbt.get_ordered_unique)
        data.subpackages = data.subpackages\
            .apply(lambda x: list(filter(lambda y: y != '', x)))

        data['dependencies'] = data.fullpath\
            .apply(RepoETL._get_imports).apply(lbt.get_ordered_unique)
        data.dependencies += data.node_name\
            .apply(lambda x: ['.'.join(x.split('.')[:-1])])
        data.dependencies = data.dependencies\
            .apply(lambda x: list(filter(lambda y: y != '', x)))

        data['node_type'] = 'module'

        # add subpackages as nodes
        pkgs = set(chain(*data.subpackages.tolist()))  # type: Any
        pkgs = pkgs.difference(data.node_name.tolist())
        pkgs = sorted(list(pkgs))
        pkgs = Series(pkgs)\
            .apply(
                lambda x: dict(
                    node_name=x,
                    node_type='subpackage',
                    dependencies=rpt.get_parent_fields(x, '.'),
                    subpackages=rpt.get_parent_fields(x, '.'),
                )).tolist()
        pkgs = DataFrame(pkgs)
        data = pd.concat([data, pkgs], ignore_index=True, sort=True)

        # add library dependencies as nodes
        libs = set(chain(*data.dependencies.tolist()))  # type: Any
        libs = libs.difference(data.node_name.tolist())
        libs = sorted(list(libs))
        libs = Series(libs)\
            .apply(
                lambda x: dict(
                    node_name=x,
                    node_type='library',
                    dependencies=[],
                    subpackages=[],
                )).tolist()
        libs = DataFrame(libs)
        data = pd.concat([data, libs], ignore_index=True, sort=True)

        data.drop_duplicates('node_name', inplace=True)
        data.reset_index(drop=True, inplace=True)

        # define node coordinates
        data['x'] = 0
        data['y'] = 0
        data = RepoETL._calculate_coordinates(data)
        data = RepoETL._anneal_coordinate(data, 'x', 'y')
        data = RepoETL._center_coordinate(data, 'x', 'y')

        data.sort_values('fullpath', inplace=True)
        data.reset_index(drop=True, inplace=True)

        cols = [
            'node_name',
            'node_type',
            'x',
            'y',
            'dependencies',
            'subpackages',
            'fullpath',
        ]
        data = data[cols]
        return data

    @staticmethod
    def _calculate_coordinates(data):
        # type: (DataFrame) -> DataFrame
        '''
        Calculate initial x, y coordinates for each node in given DataFrame.
        Nodes are stratified by type along the y axis.

        Args:
            data (DataFrame): DataFrame of nodes.

        Returns:
            DataFrame: DataFrame with x and y coordinate columns.
        '''
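        # Illustrative note (added comment, not in the original source): given a
        # subpackage 'a', modules 'a.b' and 'a.c', and library 'networkx', the
        # logic below places libraries at the largest y, subpackages beneath
        # them, and modules at the smallest y once the axis is reversed.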

        # set initial node coordinates
        data['y'] = 0
        for item in ['module', 'subpackage', 'library']:
            mask = data.node_type == item
            n = data[mask].shape[0]

            index = data[mask].index
            data.loc[index, 'x'] = list(range(n))

            # move non-library nodes down the y axis according to how nested
            # they are
            if item != 'library':
                data.loc[index, 'y'] = data.loc[index, 'node_name']\
                    .apply(lambda x: len(x.split('.')))

        # move all module nodes beneath subpackage nodes on the y axis
        max_ = data[data.node_type == 'subpackage'].y.max()
        index = data[data.node_type == 'module'].index
        data.loc[index, 'y'] += max_
        data.loc[index, 'y'] += data.loc[index, 'subpackages'].apply(len)

        # reverse y axis
        max_ = data.y.max()
        data.y = -1 * data.y + max_

        return data

    @staticmethod
    def _anneal_coordinate(data, anneal_axis='x', pin_axis='y', iterations=10):
        # type: (DataFrame, str, str, int) -> DataFrame
        '''
        Iteratively align nodes in the anneal axis according to the mean
        position of their connected nodes. Node anneal coordinates are rectified
        at the end of each iteration according to a pin axis, so that they do
        not overlap. This means that they are sorted at each level of the pin
        axis.

        Args:
            data (DataFrame): DataFrame with x column.
            anneal_axis (str, optional): Coordinate column to be annealed.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.
            iterations (int, optional): Number of times to update x coordinates.
                Default: 10.

        Returns:
            DataFrame: DataFrame with annealed anneal axis coordinates.
        '''
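        # Note (added comment, not in the original source): each iteration pulls
        # every node's anneal coordinate toward the mean coordinate of the nodes
        # reachable from it via BFS, reversing edge direction on alternate
        # iterations, then re-ranks nodes within each pin-axis level so that no
        # two overlap.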

        x = anneal_axis
        y = pin_axis
        data[x] = data[x].astype(float)
        data[y] = data[y].astype(float)
        for iteration in range(iterations):
            # create directed graph from data
            graph = RepoETL._to_networkx_graph(data)

            # reverse connectivity every other iteration
            if iteration % 2 == 0:
                graph = graph.reverse()

            # get mean coordinate of each node in directed graph
            for name in graph.nodes:
                tree = networkx.bfs_tree(graph, name)
                mu = np.mean([graph.nodes[n][x] for n in tree])
                graph.nodes[name][x] = mu

            # update data coordinate column
            for node in graph.nodes:
                mask = data[data.node_name == node].index
                data.loc[mask, x] = graph.nodes[node][x]

            # rectify data coordinate column, so that no two nodes overlap
            data.sort_values(x, inplace=True)
            for yi in data[y].unique():
                mask = data[data[y] == yi].index
                values = data.loc[mask, x].tolist()
                values = list(range(len(values)))
                data.loc[mask, x] = values

        return data

    @staticmethod
    def _center_coordinate(data, center_axis='x', pin_axis='y'):
        # (DataFrame, str, str) -> DataFrame
        '''
        Center center_axis coordinates at each level of the pin axis.

        Args:
            data (DataFrame): DataFrame with x column.
            center_axis (str, optional): Coordinate column to be centered.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.

        Returns:
            DataFrame: DataFrame with centered center axis coordinates.
        '''
        x = center_axis
        y = pin_axis
        max_ = data[x].max()
        for yi in data[y].unique():
            mask = data[data[y] == yi].index
            l_max = data.loc[mask, x].max()
            delta = max_ - l_max
            data.loc[mask, x] += (delta / 2)
        return data

    @staticmethod
    def _to_networkx_graph(data, escape_chars=False):
        # (DataFrame, bool) -> networkx.DiGraph
        '''
        Converts given DataFrame into networkx directed graph.

        Args:
            data (DataFrame): DataFrame of nodes.
            escape_chars (bool, optional): Escape special characters. Used to
                avoid dot file errors. Default: False.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        # escape periods for dot file interpolation
        if escape_chars:
            data = data.copy()
            data.node_name = data.node_name \
                .fillna('') \
                .apply(lambda x: re.sub(r'\.', '\\.', x))
            data.dependencies = data.dependencies \
                .apply(lambda x: [re.sub(r'\.', '\\.', y) for y in x])

        graph = networkx.DiGraph()
        data.apply(
            lambda x: graph.add_node(
                x.node_name,
                **{k: getattr(x, k) for k in x.index}
            ),
            axis=1
        )

        data.apply(
            lambda x: [graph.add_edge(p, x.node_name) for p in x.dependencies],
            axis=1
        )
        return graph

    def to_networkx_graph(self):
        # () -> networkx.DiGraph
        '''
        Converts internal data into networkx directed graph.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        return RepoETL._to_networkx_graph(self._data)

    def to_dot_graph(self, orient='tb', orthogonal_edges=False, color_scheme=None):
        # (str, bool, Optional[Dict[str, str]]) -> pydot.Dot
        '''
        Converts internal data into pydot graph.

        Args:
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles only. Default: False.
            color_scheme (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME.

        Raises:
            ValueError: If orient is invalid.

        Returns:
            pydot.Dot: Dot graph of nodes.
        '''
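        # Note (added comment, not in the original source): the color_scheme
        # dict is expected to provide the keys used below: 'background', 'node',
        # 'node_library_font', 'node_subpackage_font', 'node_module_font',
        # 'edge_library', 'edge_subpackage' and 'edge_module'.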

        orient = orient.lower()
        orientations = ['tb', 'bt', 'lr', 'rl']
        if orient not in orientations:
            msg = f'Invalid orient value. {orient} not in {orientations}.'
            raise ValueError(msg)

        # set color scheme of graph
        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        # create dot graph
        graph = self._to_networkx_graph(self._data, escape_chars=True)
        dot = networkx.drawing.nx_pydot.to_pydot(graph)

        # set layout orientation
        dot.set_rankdir(orient.upper())

        # set graph background color
        dot.set_bgcolor(color_scheme['background'])

        # set edge draw type
        if orthogonal_edges:
            dot.set_splines('ortho')

        # set draw parameters for each node in graph
        for node in dot.get_nodes():
            # set node shape, color and font attributes
            node.set_shape('rect')
            node.set_style('filled')
            node.set_color(color_scheme['node'])
            node.set_fillcolor(color_scheme['node'])
            node.set_fontname('Courier')

            nx_node = re.sub('"', '', node.get_name())
            nx_node = graph.nodes[nx_node]

            # if the networkx node has no attributes, skip it
            # this should never occur, but might
            if nx_node == {}:
                continue  # pragma: no cover

            # set node x, y coordinates
            node.set_pos(f"{nx_node['x']},{nx_node['y']}!")

            # vary node font color by node type
            if nx_node['node_type'] == 'library':
                node.set_fontcolor(color_scheme['node_library_font'])
            elif nx_node['node_type'] == 'subpackage':
                node.set_fontcolor(color_scheme['node_subpackage_font'])
            else:
                node.set_fontcolor(color_scheme['node_module_font'])

        # set draw parameters for each edge in graph
        for edge in dot.get_edges():
            # get networkx source node of edge
            nx_node = dot.get_node(edge.get_source())
            nx_node = nx_node[0].get_name()
            nx_node = re.sub('"', '', nx_node)
            nx_node = graph.nodes[nx_node]

            # if the networkx source node has no attributes, skip it
            # this should never occur, but might
            if nx_node == {}:
                continue  # pragma: no cover

            # vary edge color by its source node type
            if nx_node['node_type'] == 'library':
                edge.set_color(color_scheme['edge_library'])
            elif nx_node['node_type'] == 'subpackage':
                edge.set_color(color_scheme['edge_subpackage'])
            else:
                # this line is actually covered but pytest doesn't think so
                edge.set_color(color_scheme['edge_module'])  # pragma: no cover

        return dot

    def to_dataframe(self):
        # type: () -> DataFrame
        '''
        Returns:
            DataFrame: DataFrame of nodes representing repo modules.
        '''
        return self._data.copy()

    def to_html(
        self,
        layout='dot',
        orthogonal_edges=False,
        color_scheme=None,
        as_png=False
    ):
        # type: (str, bool, Optional[Dict[str, str]], bool) -> Union[HTML, Image]
        '''
        For use in inline rendering of graph data in Jupyter Lab.

        Args:
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi.
                Default: dot.
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles only. Default: False.
            color_scheme (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME.
            as_png (bool, optional): Display graph as a PNG image instead of
                SVG. Useful for display on Github. Default: False.

        Returns:
            IPython.display.HTML or IPython.display.Image: Object for inline
            display.
        '''
        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        dot = self.to_dot_graph(
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        return rpt.dot_to_html(dot, layout=layout, as_png=as_png)

    def write(
        self,
        fullpath,
        layout='dot',
        orient='tb',
        orthogonal_edges=False,
        color_scheme=None
    ):
        # type: (Union[str, Path], str, str, bool, Optional[Dict[str, str]]) -> RepoETL
        '''
        Writes internal data to a given filepath.
        Formats supported: svg, dot, png, json.

        Args:
            fullpath (str or Path): File to be written to.
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi. Default: dot.
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles only. Default: False.
            color_scheme (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME.

        Raises:
            ValueError: If invalid file extension given.

        Returns:
            RepoETL: Self.
        '''
        if isinstance(fullpath, Path):
            fullpath = fullpath.absolute().as_posix()

        _, ext = os.path.splitext(fullpath)
        ext = re.sub(r'^\.', '', ext)
        if re.search('^json$', ext, re.I):
            self._data.to_json(fullpath, orient='records')
            return self

        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        graph = self.to_dot_graph(
            orient=orient,
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        try:
            rpt.write_dot_graph(graph, fullpath, layout=layout)
        except ValueError:
            msg = f'Invalid extension found: {ext}. '
            msg += 'Valid extensions include: svg, dot, png, json.'
            raise ValueError(msg)
        return self
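

# ------------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module; the
# repository and output paths below are hypothetical):
#
#   etl = RepoETL('/path/to/repo/python')
#   data = etl.to_dataframe()                      # inspect the node table
#   etl.write('/tmp/repo_graph.svg', orient='lr')  # render as an SVG graph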