Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
from typing import Any, Dict, Iterator, List, Optional, Union
from IPython.display import HTML

import os
import re
from itertools import chain
from pathlib import Path

import lunchbox.tools as lbt
import numpy as np
from pandas import DataFrame, Series, concat

import networkx
import rolling_pin.tools as tools
15# ------------------------------------------------------------------------------
'''
Contains the RepoETL class, which is used for converting python repository
module dependencies into a directed graph.
'''
class RepoETL:
    '''
    RepoETL is a class for extracting 1st order dependencies of modules within a
    given repository. This information is stored internally as a DataFrame and
    can be rendered as networkx, pydot or SVG graphs.
    '''
    def __init__(
        self,
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|test_|_test|mock_)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> None
        r'''
        Construct RepoETL instance.

        Args:
            root (str or Path): Full path to repository root directory.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|test_|_test|mock_)\.py$'.

        Raises:
            ValueError: If include or exclude regex does not end in '\.py$'.
        '''
        self._root = root  # type: Union[str, Path]
        self._data = self._get_data(root, include_regex, exclude_regex)  # type: DataFrame

    @staticmethod
    def _get_imports(fullpath):
        # type: (Union[str, Path]) -> List[str]
        '''
        Gets import statements from a given python module.

        Args:
            fullpath (str or Path): Path to python module.

        Returns:
            list(str): List of imported modules.
        '''
        with open(fullpath) as f:
            data = f.readlines()  # type: Union[List, Iterator]
        data = map(lambda x: x.strip('\n'), data)
        data = filter(lambda x: re.search('^import|^from', x), data)
        # reduce each import statement to the bare module name
        data = map(lambda x: re.sub('from (.*?) .*', '\\1', x), data)
        data = map(lambda x: re.sub(' as .*', '', x), data)
        data = map(lambda x: re.sub(' *#.*', '', x), data)
        data = map(lambda x: re.sub('import ', '', x), data)
        # standard library modules are not treated as dependencies
        data = filter(lambda x: not lbt.is_standard_module(x), data)
        return list(data)

    @staticmethod
    def _get_data(
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|_test)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> DataFrame
        r'''
        Recursively aggregates and filters all the files found within a given
        directory into a DataFrame. Data is used to create directed graphs.

        DataFrame has these columns:

            * node_name - name of node
            * node_type - type of node, can be [module, subpackage, library]
            * x - node's x coordinate
            * y - node's y coordinate
            * dependencies - parent nodes
            * subpackages - parent nodes of type subpackage
            * fullpath - fullpath to the module a node represents

        Args:
            root (str or Path): Root directory to be searched.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|_test)\.py$'.

        Raises:
            ValueError: If include or exclude regex does not end in '\.py$'.
            FileNotFoundError: If no files are found after filtering.

        Returns:
            DataFrame: DataFrame of file information.
        '''
        files = tools.list_all_files(root)  # type: Union[Iterator, List]
        if include_regex != '':
            if not include_regex.endswith(r'\.py$'):
                msg = f"Invalid include_regex: '{include_regex}'. "
                msg += r"Does not end in '.py$'."
                raise ValueError(msg)

            files = filter(
                lambda x: re.search(include_regex, x.absolute().as_posix()),
                files
            )
        if exclude_regex != '':
            files = filter(
                lambda x: not re.search(exclude_regex, x.absolute().as_posix()),
                files
            )
        files = list(files)
        if len(files) == 0:
            msg = f'No files found after filters in directory: {root}.'
            raise FileNotFoundError(msg)

        # build DataFrame of nodes and imported dependencies
        data = DataFrame()
        data['fullpath'] = files
        data.fullpath = data.fullpath.apply(lambda x: x.absolute().as_posix())

        # escape root before using it as a regex pattern, so paths containing
        # metacharacters (e.g. '+', '(') do not corrupt node names; as_posix
        # also normalizes Path instances and separators
        root_re = re.escape(Path(root).as_posix())
        data['node_name'] = data.fullpath\
            .apply(lambda x: re.sub(root_re, '', x))\
            .apply(lambda x: re.sub(r'\.py$', '', x))\
            .apply(lambda x: re.sub('^/', '', x))\
            .apply(lambda x: re.sub('/', '.', x))

        data['subpackages'] = data.node_name\
            .apply(lambda x: tools.get_parent_fields(x, '.')).apply(lbt.get_ordered_unique)
        data.subpackages = data.subpackages\
            .apply(lambda x: list(filter(lambda y: y != '', x)))

        data['dependencies'] = data.fullpath\
            .apply(RepoETL._get_imports).apply(lbt.get_ordered_unique)
        data.dependencies += data.node_name\
            .apply(lambda x: ['.'.join(x.split('.')[:-1])])
        data.dependencies = data.dependencies\
            .apply(lambda x: list(filter(lambda y: y != '', x)))
        data['node_type'] = 'module'

        # add subpackages as nodes
        pkgs = set(chain(*data.subpackages.tolist()))  # type: Any
        pkgs = pkgs.difference(data.node_name.tolist())
        pkgs = sorted(list(pkgs))
        pkgs = Series(pkgs)\
            .apply(
                lambda x: dict(
                    node_name=x,
                    node_type='subpackage',
                    dependencies=tools.get_parent_fields(x, '.'),
                    subpackages=tools.get_parent_fields(x, '.'),
                )).tolist()
        pkgs = DataFrame(pkgs)
        # DataFrame.append was removed in pandas 2.0; concat is the supported
        # equivalent and behaves identically here (ignore_index + column sort)
        data = concat([data, pkgs], ignore_index=True, sort=True)

        # add library dependencies as nodes
        libs = set(chain(*data.dependencies.tolist()))  # type: Any
        libs = libs.difference(data.node_name.tolist())
        libs = sorted(list(libs))
        libs = Series(libs)\
            .apply(
                lambda x: dict(
                    node_name=x,
                    node_type='library',
                    dependencies=[],
                    subpackages=[],
                )).tolist()
        libs = DataFrame(libs)
        data = concat([data, libs], ignore_index=True, sort=True)

        data.drop_duplicates('node_name', inplace=True)
        data.reset_index(drop=True, inplace=True)

        # define node coordinates
        data['x'] = 0
        data['y'] = 0
        data = RepoETL._calculate_coordinates(data)
        data = RepoETL._anneal_coordinate(data, 'x', 'y')
        data = RepoETL._center_coordinate(data, 'x', 'y')

        data.sort_values('fullpath', inplace=True)
        data.reset_index(drop=True, inplace=True)

        cols = [
            'node_name',
            'node_type',
            'x',
            'y',
            'dependencies',
            'subpackages',
            'fullpath',
        ]
        data = data[cols]
        return data

    @staticmethod
    def _calculate_coordinates(data):
        # type: (DataFrame) -> DataFrame
        '''
        Calculate initial x, y coordinates for each node in given DataFrame.
        Nodes are stratified by type along the y axis.

        Args:
            data (DataFrame): DataFrame of nodes.

        Returns:
            DataFrame: DataFrame with x and y coordinate columns.
        '''
        # set initial node coordinates
        data['y'] = 0
        for item in ['module', 'subpackage', 'library']:
            mask = data.node_type == item
            n = data[mask].shape[0]

            index = data[mask].index
            data.loc[index, 'x'] = list(range(n))

            # move non-library nodes down the y axis according to how nested
            # they are
            if item != 'library':
                data.loc[index, 'y'] = data.loc[index, 'node_name']\
                    .apply(lambda x: len(x.split('.')))

        # move all module nodes beneath subpackage nodes on the y axis
        max_ = data[data.node_type == 'subpackage'].y.max()
        if np.isnan(max_):
            # repo has no subpackages; .max() of an empty selection is NaN and
            # would poison every module y coordinate
            max_ = 0
        index = data[data.node_type == 'module'].index
        data.loc[index, 'y'] += max_
        data.loc[index, 'y'] += data.loc[index, 'subpackages'].apply(len)

        # reverse y axis
        max_ = data.y.max()
        data.y = -1 * data.y + max_

        return data

    @staticmethod
    def _anneal_coordinate(data, anneal_axis='x', pin_axis='y', iterations=10):
        # type: (DataFrame, str, str, int) -> DataFrame
        '''
        Iteratively align nodes in the anneal axis according to the mean
        position of their connected nodes. Node anneal coordinates are rectified
        at the end of each iteration according to a pin axis, so that they do
        not overlap. This means that they are sorted at each level of the pin
        axis.

        Args:
            data (DataFrame): DataFrame with x column.
            anneal_axis (str, optional): Coordinate column to be annealed.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.
            iterations (int, optional): Number of times to update x coordinates.
                Default: 10.

        Returns:
            DataFrame: DataFrame with annealed anneal axis coordinates.
        '''
        x = anneal_axis
        y = pin_axis
        for iteration in range(iterations):
            # create directed graph from data
            graph = RepoETL._to_networkx_graph(data)

            # reverse connectivity every other iteration
            if iteration % 2 == 0:
                graph = graph.reverse()

            # get mean coordinate of each node in directed graph
            for name in graph.nodes:
                tree = networkx.bfs_tree(graph, name)
                mu = np.mean([graph.nodes[n][x] for n in tree])
                graph.nodes[name][x] = mu

            # update data coordinate column
            for node in graph.nodes:
                mask = data[data.node_name == node].index
                data.loc[mask, x] = graph.nodes[node][x]

            # rectify data coordinate column, so that no two nodes overlap
            data.sort_values(x, inplace=True)
            for yi in data[y].unique():
                mask = data[data[y] == yi].index
                values = data.loc[mask, x].tolist()
                values = list(range(len(values)))
                data.loc[mask, x] = values

        return data

    @staticmethod
    def _center_coordinate(data, center_axis='x', pin_axis='y'):
        # type: (DataFrame, str, str) -> DataFrame
        '''
        Center center_axis coordinates at each level of the pin axis.

        Args:
            data (DataFrame): DataFrame with x column.
            center_axis (str, optional): Coordinate column to be centered.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.

        Returns:
            DataFrame: DataFrame with centered center axis coordinates.
        '''
        x = center_axis
        y = pin_axis
        max_ = data[x].max()
        for yi in data[y].unique():
            mask = data[data[y] == yi].index
            l_max = data.loc[mask, x].max()
            delta = max_ - l_max
            data.loc[mask, x] += (delta / 2)
        return data

    @staticmethod
    def _to_networkx_graph(data):
        # type: (DataFrame) -> networkx.DiGraph
        '''
        Converts given DataFrame into networkx directed graph.

        Args:
            data (DataFrame): DataFrame of nodes.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        graph = networkx.DiGraph()
        # every row becomes a node carrying all of its columns as attributes
        data.apply(
            lambda x: graph.add_node(
                x.node_name,
                **{k: getattr(x, k) for k in x.index}
            ),
            axis=1
        )
        # edges point from dependency to dependent node
        data.apply(
            lambda x: [graph.add_edge(p, x.node_name) for p in x.dependencies],
            axis=1
        )
        return graph

    def to_networkx_graph(self):
        # type: () -> networkx.DiGraph
        '''
        Converts internal data into networkx directed graph.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        return RepoETL._to_networkx_graph(self._data)

    def to_dot_graph(self, orient='tb', orthogonal_edges=False, color_scheme=None):
        # type: (str, bool, Optional[Dict[str, str]]) -> pydot.Dot
        '''
        Converts internal data into pydot graph.

        Args:
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles (Graphviz ortho splines).
                Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If orient is invalid.

        Returns:
            pydot.Dot: Dot graph of nodes.
        '''
        orient = orient.lower()
        orientations = ['tb', 'bt', 'lr', 'rl']
        if orient not in orientations:
            msg = f'Invalid orient value. {orient} not in {orientations}.'
            raise ValueError(msg)

        # set color scheme of graph
        if color_scheme is None:
            color_scheme = tools.COLOR_SCHEME

        # create dot graph
        graph = self.to_networkx_graph()
        dot = networkx.drawing.nx_pydot.to_pydot(graph)

        # set layout orientation
        dot.set_rankdir(orient.upper())

        # set graph background color
        dot.set_bgcolor(color_scheme['background'])

        # set edge draw type
        if orthogonal_edges:
            dot.set_splines('ortho')

        # set draw parameters for each node in graph
        for node in dot.get_nodes():
            # set node shape, color and font attributes
            node.set_shape('rect')
            node.set_style('filled')
            node.set_color(color_scheme['node'])
            node.set_fillcolor(color_scheme['node'])
            node.set_fontname('Courier')

            # pydot quotes node names; strip quotes to look up the nx node
            nx_node = re.sub('"', '', node.get_name())
            nx_node = graph.nodes[nx_node]

            # if networkx node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # set node x, y coordinates
            node.set_pos(f"{nx_node['x']},{nx_node['y']}!")

            # vary node font color by node type
            if nx_node['node_type'] == 'library':
                node.set_fontcolor(color_scheme['node_library_font'])
            elif nx_node['node_type'] == 'subpackage':
                node.set_fontcolor(color_scheme['node_subpackage_font'])
            else:
                node.set_fontcolor(color_scheme['node_module_font'])

        # set draw parameters for each edge in graph
        for edge in dot.get_edges():
            # get networkx source node of edge
            nx_node = dot.get_node(edge.get_source())
            nx_node = nx_node[0].get_name()
            nx_node = re.sub('"', '', nx_node)
            nx_node = graph.nodes[nx_node]

            # if networkx source node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # vary edge color by its source node type
            if nx_node['node_type'] == 'library':
                edge.set_color(color_scheme['edge_library'])
            elif nx_node['node_type'] == 'subpackage':
                edge.set_color(color_scheme['edge_subpackage'])
            else:
                # this line is actually covered but pytest doesn't think so
                edge.set_color(color_scheme['edge_module'])  # pragma: no cover

        return dot

    def to_dataframe(self):
        # type: () -> DataFrame
        '''
        Returns:
            DataFrame: DataFrame of nodes representing repo modules.
        '''
        return self._data.copy()

    def to_html(
        self,
        layout='dot',
        orthogonal_edges=False,
        color_scheme=None,
        as_png=False
    ):
        # type: (str, bool, Optional[Dict[str, str]], bool) -> HTML
        '''
        For use in inline rendering of graph data in Jupyter Lab.

        Args:
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi.
                Default: dot.
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME
            as_png (bool, optional): Display graph as a PNG image instead of
                SVG. Useful for display on Github. Default: False.

        Returns:
            IPython.display.HTML: HTML object for inline display.
        '''
        if color_scheme is None:
            color_scheme = tools.COLOR_SCHEME

        dot = self.to_dot_graph(
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        return tools.dot_to_html(dot, layout=layout, as_png=as_png)

    def write(
        self,
        fullpath,
        layout='dot',
        orient='tb',
        orthogonal_edges=False,
        color_scheme=None
    ):
        # type: (Union[str, Path], str, str, bool, Optional[Dict[str, str]]) -> RepoETL
        '''
        Writes internal data to a given filepath.
        Formats supported: svg, dot, png, json.

        Args:
            fullpath (str or Path): File to be written to.
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi. Default: dot.
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles. Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If invalid file extension given.

        Returns:
            RepoETL: Self.
        '''
        if isinstance(fullpath, Path):
            fullpath = fullpath.absolute().as_posix()

        _, ext = os.path.splitext(fullpath)
        ext = re.sub(r'^\.', '', ext)
        # json is serialized directly from the internal DataFrame
        if re.search('^json$', ext, re.I):
            self._data.to_json(fullpath, orient='records')
            return self

        if color_scheme is None:
            color_scheme = tools.COLOR_SCHEME

        graph = self.to_dot_graph(
            orient=orient,
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        try:
            tools.write_dot_graph(graph, fullpath, layout=layout)
        except ValueError as e:
            # chain the underlying error so the original cause is preserved
            msg = f'Invalid extension found: {ext}. '
            msg += 'Valid extensions include: svg, dot, png, json.'
            raise ValueError(msg) from e
        return self