Coverage for /home/ubuntu/rolling-pin/python/rolling_pin/repo_etl.py: 100%

202 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-11-15 00:43 +0000

1from typing import Any, Dict, Iterator, List, Optional, Union # noqa: F401 

2from IPython.display import HTML, Image # noqa: F401 

3 

4from itertools import chain 

5from pathlib import Path 

6import os 

7import re 

8 

9from pandas import DataFrame, Series 

10import lunchbox.tools as lbt 

11import networkx 

12import numpy as np 

13import pandas as pd 

14 

15import rolling_pin.tools as rpt 

16# ------------------------------------------------------------------------------ 

17 

18''' 

19Contains the RepoETL class, which is used for converting python repository module 

20dependencies into a directed graph. 

21''' 

22 

23 

class RepoETL():
    '''
    RepoETL is a class for extracting 1st order dependencies of modules within a
    given repository. This information is stored internally as a DataFrame and
    can be rendered as networkx, pydot or SVG graphs.
    '''
    def __init__(
        self,
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|test_|_test|mock_)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> None
        r'''
        Construct RepoETL instance.

        Args:
            root (str or Path): Full path to repository root directory.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|test_|_test|mock_)\.py$'.

        Raises:
            ValueError: If include_regex is not empty and does not end in
                '\.py$'.
            FileNotFoundError: If no files are found after filtering.
        '''
        self._root = root  # type: Union[str, Path]
        self._data = self._get_data(root, include_regex, exclude_regex)  # type: DataFrame

    @staticmethod
    def _get_imports(fullpath):
        # type: (Union[str, Path]) -> List[str]
        '''
        Gets the modules imported by a given python module.

        Only single-line statements that start at the beginning of a line with
        "import" or "from" are detected. Names for which
        lunchbox.tools.is_standard_module returns True are dropped.

        Args:
            fullpath (str or Path): Path to python module.

        Returns:
            list(str): List of imported modules, in order of appearance.
        '''
        output = []  # type: List[str]

        # python source files are UTF-8 by default, so do not depend on the
        # locale encoding
        with open(fullpath, encoding='utf-8') as f:
            for line in f:
                line = line.strip('\n')

                # the trailing \s keeps names such as "import_data" or
                # "fromage" from being mistaken for import statements
                if not re.search(r'^(import|from)\s', line):
                    continue

                # drop trailing comments before parsing module names
                line = re.sub(' *#.*', '', line)
                if line.startswith('from'):
                    match = re.search(r'^from\s+(\S+)', line)
                    names = [match.group(1)] if match else []
                else:
                    # "import a, b as c" names more than one module
                    names = re.sub(r'^import\s+', '', line).split(',')

                for name in names:
                    name = re.sub(r'\s+as\s.*', '', name).strip()
                    if name != '' and not lbt.is_standard_module(name):
                        output.append(name)
        return output

    @staticmethod
    def _get_data(
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|_test)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> DataFrame
        r'''
        Recursively aggregates and filters all the files found with a given
        directory into a DataFrame. Data is used to create directed graphs.

        DataFrame has these columns:

            * node_name    - name of node
            * node_type    - type of node, can be [module, subpackage, library]
            * x            - node's x coordinate
            * y            - node's y coordinate
            * dependencies - parent nodes
            * subpackages  - parent nodes of type subpackage
            * fullpath     - fullpath to the module a node represents

        Args:
            root (str or Path): Root directory to be searched.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|_test)\.py$'.

        Raises:
            ValueError: If include_regex is not empty and does not end in
                '\.py$'.
            FileNotFoundError: If no files are found after filtering.

        Returns:
            DataFrame: DataFrame of file information.
        '''
        root = Path(root).as_posix()
        files = rpt.list_all_files(root)  # type: Union[Iterator, List]
        if include_regex != '':
            if not include_regex.endswith(r'\.py$'):
                msg = f"Invalid include_regex: '{include_regex}'. "
                msg += r"Does not end in '.py$'."
                raise ValueError(msg)

            files = filter(
                lambda x: re.search(include_regex, x.absolute().as_posix()),
                files
            )
        if exclude_regex != '':
            files = filter(
                lambda x: not re.search(exclude_regex, x.absolute().as_posix()),
                files
            )

        files = list(files)
        if not files:
            msg = f'No files found after filters in directory: {root}.'
            raise FileNotFoundError(msg)

        # build DataFrame of nodes and imported dependencies
        data = DataFrame()
        data['fullpath'] = files
        data.fullpath = data.fullpath.apply(lambda x: x.absolute().as_posix())

        # fullpath is absolute, so the root has to be made absolute as well
        # before it can be removed. It is removed as a literal leading prefix
        # rather than a regex, because a path may contain regex metacharacters
        # and may repeat the root text further down the tree.
        prefix = Path(root).absolute().as_posix()

        def get_node_name(fullpath):
            # type: (str) -> str
            name = fullpath
            if name.startswith(prefix):
                name = name[len(prefix):]
            name = re.sub(r'\.py$', '', name)
            name = re.sub('^/', '', name)
            return name.replace('/', '.')

        data['node_name'] = data.fullpath.apply(get_node_name)

        data['subpackages'] = data.node_name\
            .apply(lambda x: rpt.get_parent_fields(x, '.'))\
            .apply(lbt.get_ordered_unique)
        data.subpackages = data.subpackages\
            .apply(lambda x: list(filter(lambda y: y != '', x)))

        data['dependencies'] = data.fullpath\
            .apply(RepoETL._get_imports).apply(lbt.get_ordered_unique)

        # every module also depends on its immediate parent package
        data.dependencies += data.node_name\
            .apply(lambda x: ['.'.join(x.split('.')[:-1])])
        data.dependencies = data.dependencies\
            .apply(lambda x: list(filter(lambda y: y != '', x)))

        data['node_type'] = 'module'

        # add subpackages as nodes
        pkgs = set(chain(*data.subpackages.tolist()))  # type: Any
        pkgs = pkgs.difference(data.node_name.tolist())
        pkgs = DataFrame([
            dict(
                node_name=name,
                node_type='subpackage',
                dependencies=rpt.get_parent_fields(name, '.'),
                subpackages=rpt.get_parent_fields(name, '.'),
            )
            for name in sorted(pkgs)
        ])
        data = pd.concat([data, pkgs], ignore_index=True, sort=True)

        # add library dependencies as nodes
        libs = set(chain(*data.dependencies.tolist()))  # type: Any
        libs = libs.difference(data.node_name.tolist())
        libs = DataFrame([
            dict(
                node_name=name,
                node_type='library',
                dependencies=[],
                subpackages=[],
            )
            for name in sorted(libs)
        ])
        data = pd.concat([data, libs], ignore_index=True, sort=True)

        data.drop_duplicates('node_name', inplace=True)
        data.reset_index(drop=True, inplace=True)

        # define node coordinates
        data['x'] = 0
        data['y'] = 0
        data = RepoETL._calculate_coordinates(data)
        data = RepoETL._anneal_coordinate(data, 'x', 'y')
        data = RepoETL._center_coordinate(data, 'x', 'y')

        data.sort_values('fullpath', inplace=True)
        data.reset_index(drop=True, inplace=True)

        cols = [
            'node_name',
            'node_type',
            'x',
            'y',
            'dependencies',
            'subpackages',
            'fullpath',
        ]
        data = data[cols]
        return data

    @staticmethod
    def _calculate_coordinates(data):
        # type: (DataFrame) -> DataFrame
        '''
        Calculate initial x, y coordinates for each node in given DataFrame.
        Nodes are stratified by type along the y axis.

        Args:
            data (DataFrame): DataFrame of nodes. Reads the node_name,
                node_type and subpackages columns. Modified in place.

        Returns:
            DataFrame: DataFrame with x and y coordinate columns.
        '''
        # set initial node coordinates
        data['y'] = 0
        for item in ['module', 'subpackage', 'library']:
            index = data[data.node_type == item].index
            if index.empty:
                continue

            data.loc[index, 'x'] = list(range(len(index)))

            # move non-library nodes down the y axis according to how nested
            # they are
            if item != 'library':
                data.loc[index, 'y'] = data.loc[index, 'node_name']\
                    .apply(lambda x: len(x.split('.')))

        # move all module nodes beneath subpackage nodes on the y axis
        max_ = data[data.node_type == 'subpackage'].y.max()

        # max of an empty selection is NaN, which would turn every module y
        # coordinate into NaN when there are no subpackage nodes
        if pd.isna(max_):
            max_ = 0
        index = data[data.node_type == 'module'].index
        data.loc[index, 'y'] += max_
        data.loc[index, 'y'] += data.loc[index, 'subpackages'].apply(len)

        # reverse y axis
        max_ = data.y.max()
        data.y = -1 * data.y + max_

        return data

    @staticmethod
    def _anneal_coordinate(data, anneal_axis='x', pin_axis='y', iterations=10):
        # type: (DataFrame, str, str, int) -> DataFrame
        '''
        Iteratively align nodes in the anneal axis according to the mean
        position of their connected nodes. Node anneal coordinates are rectified
        at the end of each iteration according to a pin axis, so that they do
        not overlap. This means that they are sorted at each level of the pin
        axis.

        Args:
            data (DataFrame): DataFrame of nodes with anneal and pin axis
                columns. Modified in place.
            anneal_axis (str, optional): Coordinate column to be annealed.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.
            iterations (int, optional): Number of times to update the anneal
                axis coordinates. Default: 10.

        Returns:
            DataFrame: DataFrame with annealed anneal axis coordinates.
        '''
        x = anneal_axis
        y = pin_axis
        for iteration in range(iterations):
            # create directed graph from data
            graph = RepoETL._to_networkx_graph(data)

            # reverse connectivity every other iteration
            if iteration % 2 == 0:
                graph = graph.reverse()

            # get mean coordinate of each node in directed graph
            # nodes are updated in place, so later means see earlier updates
            for name in graph.nodes:
                tree = networkx.bfs_tree(graph, name)
                mu = np.mean([graph.nodes[n][x] for n in tree])
                graph.nodes[name][x] = mu

            # update data coordinate column
            for node in graph.nodes:
                mask = data[data.node_name == node].index
                data.loc[mask, x] = graph.nodes[node][x]

            # rectify data coordinate column, so that no two nodes overlap
            data.sort_values(x, inplace=True)
            for yi in data[y].unique():
                mask = data[data[y] == yi].index
                count = data.loc[mask, x].shape[0]
                data.loc[mask, x] = list(range(count))

        return data

    @staticmethod
    def _center_coordinate(data, center_axis='x', pin_axis='y'):
        # (DataFrame, str, str) -> DataFrame
        '''
        Centers center_axis coordinates at each level of the pin axis.

        Each level is shifted by half the difference between the largest
        center_axis value of the whole DataFrame and the largest center_axis
        value of that level.

        Args:
            data (DataFrame): DataFrame with center and pin axis columns.
                Modified in place.
            center_axis (str, optional): Coordinate column to be centered.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.

        Returns:
            DataFrame: DataFrame with centered center axis coordinates.
        '''
        x = center_axis
        y = pin_axis
        max_ = data[x].max()
        for yi in data[y].unique():
            mask = data[data[y] == yi].index
            delta = max_ - data.loc[mask, x].max()
            data.loc[mask, x] += (delta / 2)
        return data

    @staticmethod
    def _to_networkx_graph(data):
        # (DataFrame) -> networkx.DiGraph
        '''
        Converts given DataFrame into networkx directed graph.

        Every row becomes a node keyed by node_name, with all of the row's
        columns stored as node attributes. An edge is added from each item in
        a row's dependencies to that row's node.

        Args:
            data (DataFrame): DataFrame of nodes.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        graph = networkx.DiGraph()

        # add all nodes first, so that edges do not create attribute-less
        # copies of nodes that are present in data
        for _, row in data.iterrows():
            attrs = {key: row[key] for key in row.index}
            graph.add_node(row['node_name'], **attrs)

        for _, row in data.iterrows():
            for parent in row['dependencies']:
                graph.add_edge(parent, row['node_name'])
        return graph

    def to_networkx_graph(self):
        # () -> networkx.DiGraph
        '''
        Converts internal data into networkx directed graph.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        return RepoETL._to_networkx_graph(self._data)

    def to_dot_graph(self, orient='tb', orthogonal_edges=False, color_scheme=None):
        # (str, bool, Optional[Dict[str, str]]) -> pydot.Dot
        '''
        Converts internal data into pydot graph.

        Args:
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with the 'ortho' splines setting. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If orient is invalid.

        Returns:
            pydot.Dot: Dot graph of nodes.
        '''
        orient = orient.lower()
        orientations = ['tb', 'bt', 'lr', 'rl']
        if orient not in orientations:
            msg = f'Invalid orient value. {orient} not in {orientations}.'
            raise ValueError(msg)

        # set color scheme of graph
        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        # color scheme keys by node type, any other type is drawn as a module
        font_keys = dict(
            library='node_library_font',
            subpackage='node_subpackage_font',
        )
        edge_keys = dict(
            library='edge_library',
            subpackage='edge_subpackage',
        )

        # create dot graph
        graph = self.to_networkx_graph()
        dot = networkx.drawing.nx_pydot.to_pydot(graph)

        # set layout orientation
        dot.set_rankdir(orient.upper())

        # set graph background color
        dot.set_bgcolor(color_scheme['background'])

        # set edge draw type
        if orthogonal_edges:
            dot.set_splines('ortho')

        # set draw parameters for each node in graph
        for node in dot.get_nodes():
            # set node shape, color and font attributes
            node.set_shape('rect')
            node.set_style('filled')
            node.set_color(color_scheme['node'])
            node.set_fillcolor(color_scheme['node'])
            node.set_fontname('Courier')

            # pydot names may be quoted, networkx names are not
            nx_node = re.sub('"', '', node.get_name())
            nx_node = graph.nodes[nx_node]

            # if networkx node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # set node x, y coordinates
            node.set_pos(f"{nx_node['x']},{nx_node['y']}!")

            # vary node font color by node type
            key = font_keys.get(nx_node['node_type'], 'node_module_font')
            node.set_fontcolor(color_scheme[key])

        # set draw parameters for each edge in graph
        for edge in dot.get_edges():
            # get networkx source node of edge
            nx_node = dot.get_node(edge.get_source())
            nx_node = nx_node[0].get_name()
            nx_node = re.sub('"', '', nx_node)
            nx_node = graph.nodes[nx_node]

            # if networkx source node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # vary edge color by its source node type
            key = edge_keys.get(nx_node['node_type'], 'edge_module')
            edge.set_color(color_scheme[key])

        return dot

    def to_dataframe(self):
        # type: () -> DataFrame
        '''
        Returns:
            DataFrame: Copy of the DataFrame of nodes representing repo modules.
        '''
        return self._data.copy()

    def to_html(
        self,
        layout='dot',
        orthogonal_edges=False,
        color_scheme=None,
        as_png=False,
        orient='tb',
    ):
        # type: (str, bool, Optional[Dict[str, str]], bool, str) -> Union[HTML, Image]
        '''
        For use in inline rendering of graph data in Jupyter Lab.

        Args:
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi.
                Default: dot.
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with the 'ortho' splines setting. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME
            as_png (bool, optional): Display graph as a PNG image instead of
                SVG. Useful for display on Github. Default: False.
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left

        Raises:
            ValueError: If orient is invalid.

        Returns:
            HTML or Image: Object for inline display, as produced by
                rolling_pin.tools.dot_to_html.
        '''
        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        dot = self.to_dot_graph(
            orient=orient,
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        return rpt.dot_to_html(dot, layout=layout, as_png=as_png)

    def write(
        self,
        fullpath,
        layout='dot',
        orient='tb',
        orthogonal_edges=False,
        color_scheme=None
    ):
        # type: (Union[str, Path], str, str, bool, Optional[Dict[str, str]]) -> RepoETL
        '''
        Writes internal data to a given filepath.
        Formats supported: svg, dot, png, json.

        Args:
            fullpath (str or Path): File to be written to.
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi. Default: dot.
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with the 'ortho' splines setting. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If invalid file extension given.

        Returns:
            RepoETL: Self.
        '''
        if isinstance(fullpath, Path):
            fullpath = fullpath.absolute().as_posix()

        _, ext = os.path.splitext(fullpath)
        ext = re.sub(r'^\.', '', ext)

        # json is written straight from the DataFrame, no graph is needed
        if re.search('^json$', ext, re.I):
            self._data.to_json(fullpath, orient='records')
            return self

        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        graph = self.to_dot_graph(
            orient=orient,
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        try:
            rpt.write_dot_graph(graph, fullpath, layout=layout)
        except ValueError as error:
            msg = f'Invalid extension found: {ext}. '
            msg += 'Valid extensions include: svg, dot, png, json.'
            raise ValueError(msg) from error
        return self