Coverage for /home/ubuntu/rolling-pin/python/rolling_pin/repo_etl.py: 100%
202 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-11-15 00:43 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-11-15 00:43 +0000
1from typing import Any, Dict, Iterator, List, Optional, Union # noqa: F401
2from IPython.display import HTML, Image # noqa: F401
4from itertools import chain
5from pathlib import Path
6import os
7import re
9from pandas import DataFrame, Series
10import lunchbox.tools as lbt
11import networkx
12import numpy as np
13import pandas as pd
15import rolling_pin.tools as rpt
16# ------------------------------------------------------------------------------
'''
Contains the RepoETL class, which is used for converting python repository
module dependencies into a directed graph.
'''
class RepoETL():
    '''
    RepoETL is a class for extracting 1st order dependencies of modules within a
    given repository. This information is stored internally as a DataFrame and
    can be rendered as networkx, pydot or SVG graphs.
    '''
    def __init__(
        self,
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|test_|_test|mock_)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> None
        r'''
        Construct RepoETL instance.

        Args:
            root (str or Path): Full path to repository root directory.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|test_|_test|mock_)\.py$'.

        Raises:
            ValueError: If include_regex is non-empty and does not end in
                '\.py$'.
            FileNotFoundError: If no files are found after filtering.
        '''
        self._root = root  # type: Union[str, Path]
        self._data = self._get_data(root, include_regex, exclude_regex)  # type: DataFrame

    @staticmethod
    def _get_imports(fullpath):
        # type: (Union[str, Path]) -> List[str]
        '''
        Gets import statements from a given python module.

        Only lines that begin (at column 0) with the ``import`` or ``from``
        keyword are considered. Each is reduced to the name that follows the
        keyword, and names for which lbt.is_standard_module returns True are
        dropped.

        Args:
            fullpath (str or Path): Path to python module.

        Returns:
            list(str): List of imported modules.
        '''
        # python source is UTF-8 by default, so do not rely on the locale
        with open(fullpath, encoding='utf-8') as f:
            lines = f.readlines()

        output = []
        for line in lines:
            line = line.strip('\n')

            # the word boundary keeps identifiers that merely start with
            # "import" or "from" (important = 1, from_x = 2) from matching
            if not re.search(r'^(?:import|from)\b', line):
                continue

            line = re.sub('from (.*?) .*', '\\1', line)
            line = re.sub(' as .*', '', line)
            line = re.sub(' *#.*', '', line)
            line = re.sub('import ', '', line)
            if not lbt.is_standard_module(line):
                output.append(line)
        return output

    @staticmethod
    def _get_data(
        root,
        include_regex=r'.*\.py$',
        exclude_regex=r'(__init__|_test)\.py$',
    ):
        # type: (Union[str, Path], str, str) -> DataFrame
        r'''
        Recursively aggregates and filters all the files found with a given
        directory into a DataFrame. Data is used to create directed graphs.

        DataFrame has these columns:

            * node_name    - name of node
            * node_type    - type of node, can be [module, subpackage, library]
            * x            - node's x coordinate
            * y            - node's y coordinate
            * dependencies - parent nodes
            * subpackages  - parent nodes of type subpackage
            * fullpath     - fullpath to the module a node represents

        Args:
            root (str or Path): Root directory to be searched.
            include_regex (str, optional): Files to be included in recursive
                directory search. Default: '.*\.py$'.
            exclude_regex (str, optional): Files to be excluded in recursive
                directory search. Default: '(__init__|_test)\.py$'.

        Raises:
            ValueError: If include_regex is non-empty and does not end in
                '\.py$'. The exclude_regex is not validated.
            FileNotFoundError: If no files are found after filtering.

        Returns:
            DataFrame: DataFrame of file information.
        '''
        root = Path(root).as_posix()

        # the items yielded here are used as Path objects below
        files = rpt.list_all_files(root)  # type: Union[Iterator, List]
        if include_regex != '':
            if not include_regex.endswith(r'\.py$'):
                msg = f"Invalid include_regex: '{include_regex}'. "
                msg += r"Does not end in '.py$'."
                raise ValueError(msg)

            files = filter(
                lambda x: re.search(include_regex, x.absolute().as_posix()),
                files
            )
        if exclude_regex != '':
            files = filter(
                lambda x: not re.search(exclude_regex, x.absolute().as_posix()),
                files
            )

        files = list(files)
        if len(files) == 0:
            msg = f'No files found after filters in directory: {root}.'
            raise FileNotFoundError(msg)

        # build DataFrame of nodes and imported dependencies
        data = DataFrame()
        data['fullpath'] = files
        data['fullpath'] = data.fullpath.apply(lambda x: x.absolute().as_posix())

        # root is a literal path, not a pattern, so escape it before handing
        # it to re.sub (directories named "c++" or "(copy)" are legal)
        root_pattern = re.escape(root)

        def to_node_name(fullpath):
            # /root/pkg/mod.py -> pkg.mod
            name = re.sub(root_pattern, '', fullpath)
            name = re.sub(r'\.py$', '', name)
            name = re.sub('^/', '', name)
            return re.sub('/', '.', name)

        data['node_name'] = data.fullpath.apply(to_node_name)

        data['subpackages'] = data.node_name\
            .apply(lambda x: rpt.get_parent_fields(x, '.'))\
            .apply(lbt.get_ordered_unique)\
            .apply(lambda x: [y for y in x if y != ''])

        # each module depends on its imports plus its immediate parent package
        imports = data.fullpath\
            .apply(RepoETL._get_imports)\
            .apply(lbt.get_ordered_unique)
        parents = data.node_name\
            .apply(lambda x: ['.'.join(x.split('.')[:-1])])
        data['dependencies'] = (imports + parents)\
            .apply(lambda x: [y for y in x if y != ''])

        data['node_type'] = 'module'

        # add subpackages as nodes
        pkgs = set(chain(*data.subpackages.tolist()))  # type: Any
        pkgs = pkgs.difference(data.node_name.tolist())
        pkgs = [
            dict(
                node_name=name,
                node_type='subpackage',
                dependencies=rpt.get_parent_fields(name, '.'),
                subpackages=rpt.get_parent_fields(name, '.'),
            )
            for name in sorted(pkgs)
        ]
        data = pd.concat([data, DataFrame(pkgs)], ignore_index=True, sort=True)

        # add library dependencies as nodes
        libs = set(chain(*data.dependencies.tolist()))  # type: Any
        libs = libs.difference(data.node_name.tolist())
        libs = [
            dict(
                node_name=name,
                node_type='library',
                dependencies=[],
                subpackages=[],
            )
            for name in sorted(libs)
        ]
        data = pd.concat([data, DataFrame(libs)], ignore_index=True, sort=True)

        data.drop_duplicates('node_name', inplace=True)
        data.reset_index(drop=True, inplace=True)

        # define node coordinates
        data['x'] = 0
        data['y'] = 0
        data = RepoETL._calculate_coordinates(data)
        data = RepoETL._anneal_coordinate(data, 'x', 'y')
        data = RepoETL._center_coordinate(data, 'x', 'y')

        data.sort_values('fullpath', inplace=True)
        data.reset_index(drop=True, inplace=True)

        cols = [
            'node_name',
            'node_type',
            'x',
            'y',
            'dependencies',
            'subpackages',
            'fullpath',
        ]
        data = data[cols]
        return data

    @staticmethod
    def _calculate_coordinates(data):
        # type: (DataFrame) -> DataFrame
        '''
        Calculate initial x, y coordinates for each node in given DataFrame.
        Nodes are stratified by type along the y axis.

        The given DataFrame is modified in place and also returned. It must
        have node_name, node_type and subpackages columns.

        Args:
            data (DataFrame): DataFrame of nodes.

        Returns:
            DataFrame: DataFrame with x and y coordinate columns.
        '''
        # set initial node coordinates
        data['y'] = 0
        for item in ['module', 'subpackage', 'library']:
            index = data[data.node_type == item].index
            n = len(index)
            if n == 0:
                # nothing to place for this node type
                continue

            data.loc[index, 'x'] = list(range(n))

            # move non-library nodes down the y axis according to how nested
            # they are
            if item != 'library':
                data.loc[index, 'y'] = data.loc[index, 'node_name']\
                    .apply(lambda x: len(x.split('.')))

        # move all module nodes beneath subpackage nodes on the y axis
        max_ = data[data.node_type == 'subpackage'].y.max()
        if pd.isna(max_):
            # a flat repository has no subpackage nodes; max() of the empty
            # selection is NaN, which would turn every module y into NaN
            max_ = 0
        index = data[data.node_type == 'module'].index
        data.loc[index, 'y'] += max_
        data.loc[index, 'y'] += data.loc[index, 'subpackages'].apply(len)

        # reverse y axis
        max_ = data.y.max()
        data['y'] = -1 * data.y + max_

        return data

    @staticmethod
    def _anneal_coordinate(data, anneal_axis='x', pin_axis='y', iterations=10):
        # type: (DataFrame, str, str, int) -> DataFrame
        '''
        Iteratively align nodes in the anneal axis according to the mean
        position of their connected nodes. Node anneal coordinates are rectified
        at the end of each iteration according to a pin axis, so that they do
        not overlap. This means that they are sorted at each level of the pin
        axis.

        The given DataFrame is modified in place (including its row order) and
        also returned. The anneal axis column is converted to float.

        Args:
            data (DataFrame): DataFrame with x column.
            anneal_axis (str, optional): Coordinate column to be annealed.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.
            iterations (int, optional): Number of times to update x coordinates.
                Default: 10.

        Returns:
            DataFrame: DataFrame with annealed anneal axis coordinates.
        '''
        x = anneal_axis
        y = pin_axis

        # mean positions are floats; cast up front rather than relying on
        # pandas silently upcasting an integer column on assignment
        data[x] = data[x].astype(float)

        for iteration in range(iterations):
            # create directed graph from data
            graph = RepoETL._to_networkx_graph(data)

            # reverse connectivity every other iteration
            if iteration % 2 == 0:
                graph = graph.reverse()

            # get mean coordinate of each node in directed graph
            # (updates are applied immediately, so later nodes see them)
            for name in graph.nodes:
                tree = networkx.bfs_tree(graph, name)
                mu = np.mean([graph.nodes[n][x] for n in tree])
                graph.nodes[name][x] = mu

            # update data coordinate column
            for node in graph.nodes:
                mask = data[data.node_name == node].index
                data.loc[mask, x] = graph.nodes[node][x]

            # rectify data coordinate column, so that no two nodes overlap
            data.sort_values(x, inplace=True)
            for yi in data[y].unique():
                mask = data[data[y] == yi].index
                count = len(data.loc[mask, x])
                data.loc[mask, x] = list(range(count))

        return data

    @staticmethod
    def _center_coordinate(data, center_axis='x', pin_axis='y'):
        # type: (DataFrame, str, str) -> DataFrame
        '''
        Centers center_axis coordinates at each level of the pin axis.

        Every level is shifted by half the difference between the overall
        maximum of the center axis and that level's own maximum. The given
        DataFrame is modified in place and also returned. The center axis
        column is converted to float.

        Args:
            data (DataFrame): DataFrame with x column.
            center_axis (str, optional): Coordinate column to be centered.
                Default: 'x'.
            pin_axis (str, optional): Coordinate column to be held constant.
                Default: 'y'.

        Returns:
            DataFrame: DataFrame with centered center axis coordinates.
        '''
        x = center_axis
        y = pin_axis

        # half-deltas are fractional; cast up front rather than relying on
        # pandas silently upcasting an integer column on assignment
        data[x] = data[x].astype(float)

        max_ = data[x].max()
        for yi in data[y].unique():
            mask = data[data[y] == yi].index
            l_max = data.loc[mask, x].max()
            delta = max_ - l_max
            data.loc[mask, x] += (delta / 2)
        return data

    @staticmethod
    def _to_networkx_graph(data):
        # (DataFrame) -> networkx.DiGraph
        '''
        Converts given DataFrame into networkx directed graph.

        Each row becomes a node keyed by node_name, with every column stored
        as a node attribute. Each item in a row's dependencies becomes an edge
        pointing from the dependency to the row's node.

        Args:
            data (DataFrame): DataFrame of nodes.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        graph = networkx.DiGraph()
        data.apply(
            lambda x: graph.add_node(
                x.node_name,
                **{k: getattr(x, k) for k in x.index}
            ),
            axis=1
        )

        data.apply(
            lambda x: [graph.add_edge(p, x.node_name) for p in x.dependencies],
            axis=1
        )
        return graph

    def to_networkx_graph(self):
        # () -> networkx.DiGraph
        '''
        Converts internal data into networkx directed graph.

        Returns:
            networkx.DiGraph: Graph of nodes.
        '''
        return RepoETL._to_networkx_graph(self._data)

    def to_dot_graph(self, orient='tb', orthogonal_edges=False, color_scheme=None):
        # (str, bool, Optional[Dict[str, str]]) -> pydot.Dot
        '''
        Converts internal data into pydot graph.

        Args:
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles (splines set to 'ortho').
                Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If orient is invalid.

        Returns:
            pydot.Dot: Dot graph of nodes.
        '''
        orient = orient.lower()
        orientations = ['tb', 'bt', 'lr', 'rl']
        if orient not in orientations:
            msg = f'Invalid orient value. {orient} not in {orientations}.'
            raise ValueError(msg)

        # set color scheme of graph
        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        # create dot graph
        graph = self.to_networkx_graph()
        dot = networkx.drawing.nx_pydot.to_pydot(graph)

        # set layout orientation
        dot.set_rankdir(orient.upper())

        # set graph background color
        dot.set_bgcolor(color_scheme['background'])

        # set edge draw type
        if orthogonal_edges:
            dot.set_splines('ortho')

        # set draw parameters for each node in graph
        for node in dot.get_nodes():
            # set node shape, color and font attributes
            node.set_shape('rect')
            node.set_style('filled')
            node.set_color(color_scheme['node'])
            node.set_fillcolor(color_scheme['node'])
            node.set_fontname('Courier')

            # pydot may quote node names; strip quotes to recover the
            # networkx key
            nx_node = re.sub('"', '', node.get_name())
            nx_node = graph.nodes[nx_node]

            # if networkx node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # set node x, y coordinates
            node.set_pos(f"{nx_node['x']},{nx_node['y']}!")

            # vary node font color by node type
            if nx_node['node_type'] == 'library':
                node.set_fontcolor(color_scheme['node_library_font'])
            elif nx_node['node_type'] == 'subpackage':
                node.set_fontcolor(color_scheme['node_subpackage_font'])
            else:
                node.set_fontcolor(color_scheme['node_module_font'])

        # set draw parameters for each edge in graph
        for edge in dot.get_edges():
            # get networkx source node of edge
            nx_node = dot.get_node(edge.get_source())
            nx_node = nx_node[0].get_name()
            nx_node = re.sub('"', '', nx_node)
            nx_node = graph.nodes[nx_node]

            # if networkx source node has no attributes skip it
            # this should not ever occur but might
            if nx_node == {}:
                continue  # pragma: no cover

            # vary edge color by its source node type
            if nx_node['node_type'] == 'library':
                edge.set_color(color_scheme['edge_library'])
            elif nx_node['node_type'] == 'subpackage':
                edge.set_color(color_scheme['edge_subpackage'])
            else:
                # this line is actually covered but pytest doesn't think so
                edge.set_color(color_scheme['edge_module'])  # pragma: no cover

        return dot

    def to_dataframe(self):
        # type: () -> DataFrame
        '''
        Returns:
            DataFrame: Copy of the DataFrame of nodes representing repo
                modules.
        '''
        return self._data.copy()

    def to_html(
        self,
        layout='dot',
        orthogonal_edges=False,
        color_scheme=None,
        as_png=False
    ):
        # type: (str, bool, Optional[Dict[str, str]], bool) -> Union[HTML, Image]
        '''
        For use in inline rendering of graph data in Jupyter Lab.

        Args:
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi.
                Default: dot.
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles (splines set to 'ortho').
                Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME
            as_png (bool, optional): Display graph as a PNG image instead of
                SVG. Useful for display on Github. Default: False.

        Returns:
            IPython.display.HTML or IPython.display.Image: Object for inline
                display, as produced by rolling_pin.tools.dot_to_html.
        '''
        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        dot = self.to_dot_graph(
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        return rpt.dot_to_html(dot, layout=layout, as_png=as_png)

    def write(
        self,
        fullpath,
        layout='dot',
        orient='tb',
        orthogonal_edges=False,
        color_scheme=None
    ):
        # type: (Union[str, Path], str, str, bool, Optional[Dict[str, str]]) -> RepoETL
        '''
        Writes internal data to a given filepath.
        Formats supported: svg, dot, png, json.

        A json extension (case insensitive) writes the node DataFrame as
        records; any other extension is handed to
        rolling_pin.tools.write_dot_graph.

        Args:
            fullpath (str or Path): File to be written to.
            layout (str, optional): Graph layout style.
                Options include: circo, dot, fdp, neato, sfdp, twopi.
                Default: dot.
            orient (str, optional): Graph layout orientation. Default: tb.
                Options include:

                * tb - top to bottom
                * bt - bottom to top
                * lr - left to right
                * rl - right to left
            orthogonal_edges (bool, optional): Whether graph edges should be
                drawn with right angles (splines set to 'ortho').
                Default: False.
            color_scheme: (dict, optional): Color scheme to be applied to graph.
                Default: rolling_pin.tools.COLOR_SCHEME

        Raises:
            ValueError: If invalid file extension given.

        Returns:
            RepoETL: Self.
        '''
        if isinstance(fullpath, Path):
            fullpath = fullpath.absolute().as_posix()

        _, ext = os.path.splitext(fullpath)
        ext = re.sub(r'^\.', '', ext)
        if re.search('^json$', ext, re.I):
            self._data.to_json(fullpath, orient='records')
            return self

        if color_scheme is None:
            color_scheme = rpt.COLOR_SCHEME

        graph = self.to_dot_graph(
            orient=orient,
            orthogonal_edges=orthogonal_edges,
            color_scheme=color_scheme,
        )
        try:
            rpt.write_dot_graph(graph, fullpath, layout=layout,)
        except ValueError as error:
            msg = f'Invalid extension found: {ext}. '
            msg += 'Valid extensions include: svg, dot, png, json.'
            # keep the underlying error attached for debugging
            raise ValueError(msg) from error
        return self