Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from typing import Any, Dict, Iterator, List, Optional, Union 

2from IPython.display import HTML 

3 

4import os 

5import re 

6from itertools import chain 

7from pathlib import Path 

8 

9import lunchbox.tools as lbt 

10import numpy as np 

11from pandas import DataFrame, Series 

12 

13import networkx 

14import rolling_pin.tools as tools 

15# ------------------------------------------------------------------------------ 

16 

17''' 

18Contains the RepoETL class, which is used for converting Python repository module

19dependencies into a directed graph.

20''' 

21 

22 

class RepoETL():
    '''
    RepoETL is a class for extracting 1st order dependencies of modules within a
    given repository. This information is stored internally as a DataFrame and
    can be rendered as networkx, pydot or SVG graphs.
    '''
    def __init__(
        self,
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|test_|_test|mock_)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> None
        r'''
        Construct RepoETL instance.

        Args:
            root (str or Path): Full path to repository root directory.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|test_|_test|mock_)\.py$'.

        Raises:
            ValueError: If include_regex is non-empty and does not end in
                '\.py$'.
            FileNotFoundError: If no files are found after filtering.
        '''
        self._root = root  # type: Union[str, Path]
        self._data = self._get_data(root, include_regex, exclude_regex)  # type: DataFrame

    @staticmethod
    def _get_imports(fullpath):
        # type: (Union[str, Path]) -> List[str]
        '''
        Gets import statements from a given python module.

        Only lines that begin (at column 0) with the ``import`` or ``from``
        keyword are considered. Each such line is reduced to the imported
        module name and names that lunchbox reports as standard library
        modules are dropped.

        Args:
            fullpath (str or Path): Path to python module.

        Returns:
            list(str): List of imported modules.
        '''
        # the word boundary keeps lines such as "important = 1" or
        # "fromage = 2" from being mistaken for import statements
        import_re = re.compile(r'^(?:import|from)\b')

        with open(fullpath) as f:
            lines = f.readlines()

        output = []  # type: List[str]
        for line in lines:
            line = line.strip('\n')
            if not import_re.search(line):
                continue

            # "from a.b import c" -> "a.b"
            line = re.sub('from (.*?) .*', '\\1', line)
            # drop aliases and trailing comments
            line = re.sub(' as .*', '', line)
            line = re.sub(' *#.*', '', line)
            # "import a.b" -> "a.b"
            line = re.sub('import ', '', line)

            if not lbt.is_standard_module(line):
                output.append(line)
        return output

    @staticmethod
    def _append_nodes(data, records):
        # type: (DataFrame, List[Dict[str, Any]]) -> DataFrame
        '''
        Append node records to the bottom of a node DataFrame.

        Args:
            data (DataFrame): DataFrame of nodes.
            records (list[dict]): Node records to be appended as rows.

        Returns:
            DataFrame: Given DataFrame if records is empty, otherwise a new
                DataFrame with a fresh index and the records as extra rows.
        '''
        if not records:
            return data

        # DataFrame.append was removed in pandas 2.0, concat is its
        # replacement. Imported locally because the module only imports
        # DataFrame and Series from pandas.
        from pandas import concat
        return concat([data, DataFrame(records)], ignore_index=True, sort=True)

    @staticmethod
    def _get_data(
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|_test)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> DataFrame
        r'''
        Recursively aggregates and filters all the files found with a given
        directory into a DataFrame. Data is used to create directed graphs.

        DataFrame has these columns:

            * node_name    - name of node
            * node_type    - type of node, can be [module, subpackage, library]
            * x            - node's x coordinate
            * y            - node's y coordinate
            * dependencies - parent nodes
            * subpackages  - parent nodes of type subpackage
            * fullpath     - fullpath to the module a node represents

        Args:
            root (str or Path): Root directory to be searched.
            include_regex (str, optional): Files to be included in recursive
                directory search. An empty string disables this filter.
                Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. An empty string disables this filter.
                Default: '(__init__|_test)\.py$'.

        Raises:
            ValueError: If include_regex is non-empty and does not end in
                '\.py$'. Note that exclude_regex is not validated.
            FileNotFoundError: If no files are found after filtering.

        Returns:
            DataFrame: DataFrame of file information.
        '''
        files = tools.list_all_files(root)  # type: Union[Iterator, List]
        if include_regex != '':
            if not include_regex.endswith(r'\.py$'):
                msg = f"Invalid include_regex: '{include_regex}'. "
                msg += r"Does not end in '.py$'."
                raise ValueError(msg)

            files = filter(
                lambda x: re.search(include_regex, x.absolute().as_posix()),
                files
            )
        if exclude_regex != '':
            files = filter(
                lambda x: not re.search(exclude_regex, x.absolute().as_posix()),
                files
            )

        files = list(files)
        if not files:
            msg = f'No files found after filters in directory: {root}.'
            raise FileNotFoundError(msg)

        # build DataFrame of nodes and imported dependencies
        data = DataFrame()
        data['fullpath'] = [x.absolute().as_posix() for x in files]

        # root is treated as a literal, anchored path prefix. Passing it
        # straight to re.sub would fail for Path objects and would interpret
        # characters such as "." or "+" in directory names as regex syntax.
        root_prefix = '^' + re.escape(Path(root).absolute().as_posix())
        data['node_name'] = data.fullpath\
            .apply(lambda x: re.sub(root_prefix, '', x))\
            .apply(lambda x: re.sub(r'\.py$', '', x))\
            .apply(lambda x: re.sub('^/', '', x))\
            .apply(lambda x: re.sub('/', '.', x))

        data['subpackages'] = data.node_name\
            .apply(lambda x: tools.get_parent_fields(x, '.'))\
            .apply(lbt.get_ordered_unique)\
            .apply(lambda x: [y for y in x if y != ''])

        # every module depends on its imports plus its immediate parent package
        imports = data.fullpath\
            .apply(RepoETL._get_imports)\
            .apply(lbt.get_ordered_unique)
        parents = data.node_name\
            .apply(lambda x: ['.'.join(x.split('.')[:-1])])
        data['dependencies'] = (imports + parents)\
            .apply(lambda x: [y for y in x if y != ''])

        data['node_type'] = 'module'

        # add subpackages as nodes
        pkgs = set(chain(*data.subpackages.tolist()))  # type: Any
        pkgs = pkgs.difference(data.node_name.tolist())
        pkgs = [
            dict(
                node_name=name,
                node_type='subpackage',
                dependencies=tools.get_parent_fields(name, '.'),
                subpackages=tools.get_parent_fields(name, '.'),
            )
            for name in sorted(pkgs)
        ]
        data = RepoETL._append_nodes(data, pkgs)

        # add library dependencies as nodes
        libs = set(chain(*data.dependencies.tolist()))  # type: Any
        libs = libs.difference(data.node_name.tolist())
        libs = [
            dict(
                node_name=name,
                node_type='library',
                dependencies=[],
                subpackages=[],
            )
            for name in sorted(libs)
        ]
        data = RepoETL._append_nodes(data, libs)

        data.drop_duplicates('node_name', inplace=True)
        data.reset_index(drop=True, inplace=True)

        # define node coordinates
        data['x'] = 0
        data['y'] = 0
        data = RepoETL._calculate_coordinates(data)
        data = RepoETL._anneal_coordinate(data, 'x', 'y')
        data = RepoETL._center_coordinate(data, 'x', 'y')

        data.sort_values('fullpath', inplace=True)
        data.reset_index(drop=True, inplace=True)

        cols = [
            'node_name',
            'node_type',
            'x',
            'y',
            'dependencies',
            'subpackages',
            'fullpath',
        ]
        data = data[cols]
        return data

    @staticmethod
    def _calculate_coordinates(data):
        # type: (DataFrame) -> DataFrame
        '''
        Calculate initial x, y coordinates for each node in given DataFrame.
        Nodes are stratified by type along the y axis.

        The given DataFrame is modified in place and also returned.

        Args:
            data (DataFrame): DataFrame of nodes. Requires node_name, node_type
                and subpackages columns.

        Returns:
            DataFrame: DataFrame with x and y coordinate columns.
        '''
        # set initial node coordinates
        data['y'] = 0
        for item in ['module', 'subpackage', 'library']:
            mask = data.node_type == item
            n = int(mask.sum())
            if n == 0:
                # nothing to place for this node type
                continue

            data.loc[mask, 'x'] = list(range(n))

            # move non-library nodes down the y axis according to how nested
            # they are
            if item != 'library':
                data.loc[mask, 'y'] = data.loc[mask, 'node_name']\
                    .apply(lambda x: len(x.split('.')))

        # move all module nodes beneath subpackage nodes on the y axis
        # with no subpackage nodes, max() of the empty selection is NaN and
        # would poison every module y coordinate, so fall back to 0
        pkg_y = data.loc[data.node_type == 'subpackage', 'y']
        offset = 0 if pkg_y.empty else pkg_y.max()
        modules = data.node_type == 'module'
        data.loc[modules, 'y'] = data.loc[modules, 'y'] \
            + offset \
            + data.loc[modules, 'subpackages'].apply(len)

        # reverse y axis
        max_ = data.y.max()
        data.y = -1 * data.y + max_

        return data

    @staticmethod
    def _anneal_coordinate(data, anneal_axis='x', pin_axis='y', iterations=10):
        # type: (DataFrame, str, str, int) -> DataFrame
        '''
        Iteratively align nodes in the anneal axis according to the mean
        position of their connected nodes. Node anneal coordinates are rectified
        at the end of each iteration according to a pin axis, so that they do
        not overlap. This means that they are sorted at each level of the pin
        axis.

        The given DataFrame is modified in place (including its row order) and
        also returned.

        Args:
            data (DataFrame): DataFrame with x column.
            anneal_axis (str, optional): Coordinate column to be annealed.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.
            iterations (int, optional): Number of times to update x coordinates.
                Default: 10.

        Returns:
            DataFrame: DataFrame with annealed anneal axis coordinates.
        '''
        x = anneal_axis
        y = pin_axis

        # mean positions are floats, so make the column float up front rather
        # than relying on pandas silently upcasting an integer column
        data[x] = data[x].astype(float)

        for iteration in range(iterations):
            # create directed graph from data
            graph = RepoETL._to_networkx_graph(data)

            # reverse connectivity every other iteration
            if iteration % 2 == 0:
                graph = graph.reverse()

            # get mean coordinate of each node in directed graph
            # updates are applied immediately, so later nodes see the already
            # averaged positions of earlier ones
            for name in graph.nodes:
                tree = networkx.bfs_tree(graph, name)
                mu = np.mean([graph.nodes[n][x] for n in tree])
                graph.nodes[name][x] = mu

            # update data coordinate column
            data[x] = data.node_name.map(lambda name: graph.nodes[name][x])

            # rectify data coordinate column, so that no two nodes overlap
            data.sort_values(x, inplace=True)
            for yi in data[y].unique():
                level = data[data[y] == yi].index
                data.loc[level, x] = list(range(len(level)))

        return data

    @staticmethod
    def _center_coordinate(data, center_axis='x', pin_axis='y'):
        # type: (DataFrame, str, str) -> DataFrame
        '''
        Centers center_axis coordinates at each level of the pin axis.

        Each level is shifted by half the difference between the global maximum
        and the level's own maximum of the center axis. The given DataFrame is
        modified in place and also returned.

        Args:
            data (DataFrame): DataFrame with x column.
            center_axis (str, optional): Coordinate column to be centered.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.

        Returns:
            DataFrame: DataFrame with centered center axis coordinates.
        '''
        x = center_axis
        y = pin_axis

        # half offsets are fractional, so the column must be float
        data[x] = data[x].astype(float)

        max_ = data[x].max()
        for yi in data[y].unique():
            level = data[data[y] == yi].index
            delta = max_ - data.loc[level, x].max()
            data.loc[level, x] = data.loc[level, x] + (delta / 2)
        return data

    @staticmethod
    def _to_networkx_graph(data):
        # type: (DataFrame) -> networkx.DiGraph
        '''
        Converts given DataFrame into networkx directed graph.

        Every row becomes a node keyed by node_name, with all of the row's
        columns stored as node attributes. Every item in a row's dependencies
        becomes an edge pointing from the dependency to the node.

        Args:
            data (DataFrame): DataFrame of nodes.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        graph = networkx.DiGraph()
        for _, row in data.iterrows():
            graph.add_node(row.node_name, **{k: row[k] for k in row.index})

        for _, row in data.iterrows():
            for parent in row.dependencies:
                graph.add_edge(parent, row.node_name)
        return graph

    def to_networkx_graph(self):
        # type: () -> networkx.DiGraph
        '''
        Converts internal data into networkx directed graph.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        return RepoETL._to_networkx_graph(self._data)

    def to_dot_graph(self, orient='tb', orthogonal_edges=False, color_scheme=None):
        # (str, bool, Optional[Dict[str, str]]) -> pydot.Dot
        '''
        Converts internal data into pydot graph.

        Args:
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with the 'ortho' spline type. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If orient is invalid.

        Returns:
            pydot.Dot: Dot graph of nodes.
        '''
        orient = orient.lower()
        orientations = ['tb', 'bt', 'lr', 'rl']
        if orient not in orientations:
            msg = f'Invalid orient value. {orient} not in {orientations}.'
            raise ValueError(msg)

        # set color scheme of graph
        if color_scheme is None:
            color_scheme = tools.COLOR_SCHEME

        # any node type other than library or subpackage is drawn as a module
        font_keys = dict(
            library='node_library_font',
            subpackage='node_subpackage_font',
        )
        edge_keys = dict(
            library='edge_library',
            subpackage='edge_subpackage',
        )

        # create dot graph
        graph = self.to_networkx_graph()
        dot = networkx.drawing.nx_pydot.to_pydot(graph)

        # set layout orientation
        dot.set_rankdir(orient.upper())

        # set graph background color
        dot.set_bgcolor(color_scheme['background'])

        # set edge draw type
        if orthogonal_edges:
            dot.set_splines('ortho')

        # set draw parameters for each node in graph
        for node in dot.get_nodes():
            # set node shape, color and font attributes
            node.set_shape('rect')
            node.set_style('filled')
            node.set_color(color_scheme['node'])
            node.set_fillcolor(color_scheme['node'])
            node.set_fontname('Courier')

            # pydot may quote node names, networkx names are unquoted
            nx_node = re.sub('"', '', node.get_name())
            nx_node = graph.nodes[nx_node]

            # if networkx node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # set node x, y coordinates
            node.set_pos(f"{nx_node['x']},{nx_node['y']}!")

            # vary node font color by node type
            key = font_keys.get(nx_node['node_type'], 'node_module_font')
            node.set_fontcolor(color_scheme[key])

        # set draw parameters for each edge in graph
        for edge in dot.get_edges():
            # get networkx source node of edge
            nx_node = dot.get_node(edge.get_source())
            nx_node = nx_node[0].get_name()
            nx_node = re.sub('"', '', nx_node)
            nx_node = graph.nodes[nx_node]

            # if networkx source node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # vary edge color by its source node type
            key = edge_keys.get(nx_node['node_type'], 'edge_module')
            edge.set_color(color_scheme[key])

        return dot

    def to_dataframe(self):
        # type: () -> DataFrame
        '''
        Returns:
            DataFrame: Copy of the DataFrame of nodes representing repo modules.
        '''
        return self._data.copy()

    def to_html(
        self,
        layout='dot',
        orthogonal_edges=False,
        color_scheme=None,
        as_png=False
    ):
        # type: (str, bool, Optional[Dict[str, str]], bool) -> HTML
        '''
        For use in inline rendering of graph data in Jupyter Lab.

        Args:
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi.
                Default: dot.
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with the 'ortho' spline type. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME
            as_png (bool, optional): Display graph as a PNG image instead of
                SVG. Useful for display on Github. Default: False.

        Returns:
            IPython.display.HTML: HTML object for inline display.
        '''
        if color_scheme is None:
            color_scheme = tools.COLOR_SCHEME

        dot = self.to_dot_graph(
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        return tools.dot_to_html(dot, layout=layout, as_png=as_png)

    def write(
        self,
        fullpath,
        layout='dot',
        orient='tb',
        orthogonal_edges=False,
        color_scheme=None
    ):
        # type: (Union[str, Path], str, str, bool, Optional[Dict[str, str]]) -> RepoETL
        '''
        Writes internal data to a given filepath.
        Formats supported: svg, dot, png, json.

        A json extension (case insensitive) writes the node DataFrame as
        records. Any other extension is handed to
        rolling_pin.tools.write_dot_graph.

        Args:
            fullpath (str or Path): File to be written to.
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi. Default: dot.
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with the 'ortho' spline type. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If invalid file extension given.

        Returns:
            RepoETL: Self.
        '''
        if isinstance(fullpath, Path):
            fullpath = fullpath.absolute().as_posix()

        _, ext = os.path.splitext(fullpath)
        ext = re.sub(r'^\.', '', ext)
        if re.search('^json$', ext, re.I):
            self._data.to_json(fullpath, orient='records')
            return self

        if color_scheme is None:
            color_scheme = tools.COLOR_SCHEME

        graph = self.to_dot_graph(
            orient=orient,
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        try:
            tools.write_dot_graph(graph, fullpath, layout=layout,)
        except ValueError as error:
            msg = f'Invalid extension found: {ext}. '
            msg += 'Valid extensions include: svg, dot, png, json.'
            # keep the underlying error attached for debugging
            raise ValueError(msg) from error
        return self

565 tools.write_dot_graph(graph, fullpath, layout=layout,) 

566 except ValueError: 

567 msg = f'Invalid extension found: {ext}. ' 

568 msg += 'Valid extensions include: svg, dot, png, json.' 

569 raise ValueError(msg) 

570 return self