Coverage for /home/ubuntu/rolling-pin/python/rolling_pin/repo_etl.py: 100%
208 statements
coverage.py v7.6.12, created at 2025-02-13 19:35 +0000
1from typing import Any, Dict, Iterator, List, Optional, Union # noqa: F401
2from IPython.display import HTML, Image # noqa: F401
4from itertools import chain
5from pathlib import Path
6import os
7import re
9from pandas import DataFrame, Series
10import lunchbox.tools as lbt
11import networkx
12import numpy as np
13import pandas as pd
15import rolling_pin.tools as rpt
16# ------------------------------------------------------------------------------
18'''
19Contains the RepoETL class, which is used for converting python repository module
20dependencies into a directed graph.
21'''
24class RepoETL():
25 '''
26 RepoETL is a class for extracting 1st order dependencies of modules within a
27 given repository. This information is stored internally as a DataFrame and
28 can be rendered as networkx, pydot or SVG graphs.
29 '''
30 def __init__(
31 self,
32 root,
33 include_regex=r'.*\.py$',
34 exclude_regex=r'(__init__|test_|_test|mock_)\.py$',
35 ):
36 # type: (Union[str, Path], str, str) -> None
37 r'''
38 Construct RepoETL instance.
40 Args:
41 root (str or Path): Full path to repository root directory.
42 include_regex (str, optional): Files to be included in recursive
43 directory search. Default: '.*\.py$'.
44 exclude_regex (str, optional): Files to be excluded in recursive
45 directory search. Default: '(__init__|test_|_test|mock_)\.py$'.
47 Raises:
48 ValueError: If include or exclude regex does not end in '\.py$'.
49 '''
50 self._root = root # type: Union[str, Path]
51 self._data = self._get_data(root, include_regex, exclude_regex) # type: DataFrame
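# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): constructing a RepoETL
# instance for a repository checkout. The path below is hypothetical; any
# directory containing python files matching include_regex will do.
from rolling_pin.repo_etl import RepoETL

etl = RepoETL(
    '/path/to/my_repo/python',                            # hypothetical repo root
    include_regex=r'.*\.py$',                             # keep all python files
    exclude_regex=r'(__init__|test_|_test|mock_)\.py$',   # skip inits, tests, mocks
)
# ------------------------------------------------------------------------------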
53 @staticmethod
54 def _get_imports(fullpath):
55 # type: (Union[str, Path]) -> List[str]
56 '''
57 Gets import statements from a given python module.
59 Args:
60 fullpath (str or Path): Path to python module.
62 Returns:
63 list(str): List of imported modules.
64 '''
65 with open(fullpath) as f:
66 data = f.readlines() # type: Union[List, Iterator]
67 data = map(lambda x: x.strip('\n'), data)
68 data = filter(lambda x: re.search('^import|^from', x), data)
69 data = map(lambda x: re.sub('from (.*?) .*', '\\1', x), data)
70 data = map(lambda x: re.sub(' as .*', '', x), data)
71 data = map(lambda x: re.sub(' *#.*', '', x), data)
72 data = map(lambda x: re.sub('import ', '', x), data)
73 data = filter(lambda x: not lbt.is_standard_module(x), data)
74 return list(data)
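# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): what _get_imports
# returns for a small module. Standard-library imports should be dropped by
# lbt.is_standard_module, leaving only third-party and local modules.
import tempfile
from rolling_pin.repo_etl import RepoETL

source = 'import os\nimport numpy as np\nfrom pandas import DataFrame  # comment\n'
with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as file_:
    file_.write(source)
RepoETL._get_imports(file_.name)  # expected: ['numpy', 'pandas'] (os is standard)
# ------------------------------------------------------------------------------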
76 @staticmethod
77 def _get_data(
78 root,
79 include_regex=r'.*\.py$',
80 exclude_regex=r'(__init__|_test)\.py$',
81 ):
82 # type: (Union[str, Path], str, str) -> DataFrame
83 r'''
84 Recursively aggregates and filters all the files found within a given
85 directory into a DataFrame. Data is used to create directed graphs.
87 DataFrame has these columns:
89 * node_name - name of node
90 * node_type - type of node, can be [module, subpackage, library]
91 * x - node's x coordinate
92 * y - node's y coordinate
93 * dependencies - parent nodes
94 * subpackages - parent nodes of type subpackage
95 * fullpath - fullpath to the module a node represents
97 Args:
98 root (str or Path): Root directory to be searched.
99 include_regex (str, optional): Files to be included in recursive
100 directory search. Default: '.*\.py$'.
101 exclude_regex (str, optional): Files to be excluded in recursive
102 directory search. Default: '(__init__|_test)\.py$'.
104 Raises:
105 ValueError: If include or exclude regex does not end in '\.py$'.
106 FileNotFoundError: If no files are found after filtering.
108 Returns:
109 DataFrame: DataFrame of file information.
110 '''
111 root = Path(root).as_posix()
112 files = rpt.list_all_files(root) # type: Union[Iterator, List]
113 if include_regex != '':
114 if not include_regex.endswith(r'\.py$'):
115 msg = f"Invalid include_regex: '{include_regex}'. "
116 msg += r"Does not end in '.py$'."
117 raise ValueError(msg)
119 files = filter(
120 lambda x: re.search(include_regex, x.absolute().as_posix()),
121 files
122 )
123 if exclude_regex != '':
124 files = filter(
125 lambda x: not re.search(exclude_regex, x.absolute().as_posix()),
126 files
127 )
129 files = list(files)
130 if len(files) == 0:
131 msg = f'No files found after filters in directory: {root}.'
132 raise FileNotFoundError(msg)
134 # build DataFrame of nodes and imported dependencies
135 data = DataFrame()
136 data['fullpath'] = files
137 data.fullpath = data.fullpath.apply(lambda x: x.absolute().as_posix())
139 data['node_name'] = data.fullpath\
140 .apply(lambda x: re.sub(root, '', x))\
141 .apply(lambda x: re.sub(r'\.py$', '', x))\
142 .apply(lambda x: re.sub('^/', '', x))\
143 .apply(lambda x: re.sub('/', '.', x))
145 data['subpackages'] = data.node_name\
146 .apply(lambda x: rpt.get_parent_fields(x, '.')).apply(lbt.get_ordered_unique)
147 data.subpackages = data.subpackages\
148 .apply(lambda x: list(filter(lambda y: y != '', x)))
150 data['dependencies'] = data.fullpath\
151 .apply(RepoETL._get_imports).apply(lbt.get_ordered_unique)
152 data.dependencies += data.node_name\
153 .apply(lambda x: ['.'.join(x.split('.')[:-1])])
154 data.dependencies = data.dependencies\
155 .apply(lambda x: list(filter(lambda y: y != '', x)))
157 data['node_type'] = 'module'
159 # add subpackages as nodes
160 pkgs = set(chain(*data.subpackages.tolist())) # type: Any
161 pkgs = pkgs.difference(data.node_name.tolist())
162 pkgs = sorted(list(pkgs))
163 pkgs = Series(pkgs)\
164 .apply(
165 lambda x: dict(
166 node_name=x,
167 node_type='subpackage',
168 dependencies=rpt.get_parent_fields(x, '.'),
169 subpackages=rpt.get_parent_fields(x, '.'),
170 )).tolist()
171 pkgs = DataFrame(pkgs)
172 data = pd.concat([data, pkgs], ignore_index=True, sort=True)
174 # add library dependencies as nodes
175 libs = set(chain(*data.dependencies.tolist())) # type: Any
176 libs = libs.difference(data.node_name.tolist())
177 libs = sorted(list(libs))
178 libs = Series(libs)\
179 .apply(
180 lambda x: dict(
181 node_name=x,
182 node_type='library',
183 dependencies=[],
184 subpackages=[],
185 )).tolist()
186 libs = DataFrame(libs)
187 data = pd.concat([data, libs], ignore_index=True, sort=True)
189 data.drop_duplicates('node_name', inplace=True)
190 data.reset_index(drop=True, inplace=True)
192 # define node coordinates
193 data['x'] = 0
194 data['y'] = 0
195 data = RepoETL._calculate_coordinates(data)
196 data = RepoETL._anneal_coordinate(data, 'x', 'y')
197 data = RepoETL._center_coordinate(data, 'x', 'y')
199 data.sort_values('fullpath', inplace=True)
200 data.reset_index(drop=True, inplace=True)
202 cols = [
203 'node_name',
204 'node_type',
205 'x',
206 'y',
207 'dependencies',
208 'subpackages',
209 'fullpath',
210 ]
211 data = data[cols]
212 return data
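# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): the node table built
# by _get_data for a hypothetical repository root. Each row is a module,
# subpackage or external library node.
from rolling_pin.repo_etl import RepoETL

data = RepoETL._get_data('/path/to/my_repo/python')   # hypothetical root
data.columns.tolist()
# ['node_name', 'node_type', 'x', 'y', 'dependencies', 'subpackages', 'fullpath']
sorted(data.node_type.unique())   # e.g. ['library', 'module', 'subpackage']
# ------------------------------------------------------------------------------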
214 @staticmethod
215 def _calculate_coordinates(data):
216 # type: (DataFrame) -> DataFrame
217 '''
218 Calculate initial x, y coordinates for each node in a given DataFrame.
219 Nodes are stratified by type along the y axis.
221 Args:
222 data (DataFrame): DataFrame of nodes.
224 Returns:
225 DataFrame: DataFrame with x and y coordinate columns.
226 '''
227 # set initial node coordinates
228 data['y'] = 0
229 for item in ['module', 'subpackage', 'library']:
230 mask = data.node_type == item
231 n = data[mask].shape[0]
233 index = data[mask].index
234 data.loc[index, 'x'] = list(range(n))
236 # move non-library nodes down the y axis according to how nested
237 # they are
238 if item != 'library':
239 data.loc[index, 'y'] = data.loc[index, 'node_name']\
240 .apply(lambda x: len(x.split('.')))
242 # move all module nodes beneath subpackage nodes on the y axis
243 max_ = data[data.node_type == 'subpackage'].y.max()
244 index = data[data.node_type == 'module'].index
245 data.loc[index, 'y'] += max_
246 data.loc[index, 'y'] += data.loc[index, 'subpackages'].apply(len)
248 # reverse y axis
249 max_ = data.y.max()
250 data.y = -1 * data.y + max_
252 return data
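# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): the stratified layout
# on a hand-made node table, tracing the code above. After the y axis is
# flipped, the library node gets the largest y, the subpackage a smaller y, and
# the modules the smallest y, with x simply counting 0, 1, ... within each type.
import pandas as pd
from rolling_pin.repo_etl import RepoETL

nodes = pd.DataFrame([
    dict(node_name='numpy',      node_type='library',    subpackages=[]),
    dict(node_name='mylib',      node_type='subpackage', subpackages=[]),
    dict(node_name='mylib.core', node_type='module',     subpackages=['mylib']),
    dict(node_name='mylib.cli',  node_type='module',     subpackages=['mylib']),
])
nodes['x'] = 0
nodes['y'] = 0
nodes = RepoETL._calculate_coordinates(nodes)
nodes[['node_name', 'node_type', 'x', 'y']]
# expected: numpy y=4, mylib y=3, mylib.core and mylib.cli y=0 with x=0 and x=1
# ------------------------------------------------------------------------------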
254 @staticmethod
255 def _anneal_coordinate(data, anneal_axis='x', pin_axis='y', iterations=10):
256 # type: (DataFrame, str, str, int) -> DataFrame
257 '''
258 Iteratively align nodes in the anneal axis according to the mean
259 position of their connected nodes. Node anneal coordinates are rectified
260 at the end of each iteration according to a pin axis, so that they do
261 not overlap. This means that they are sorted at each level of the pin
262 axis.
264 Args:
265 data (DataFrame): DataFrame with x column.
266 anneal_axis (str, optional): Coordinate column to be annealed.
267 Default: 'x'.
268 pin_axis (str, optional): Coordinate column to be held constant.
269 Default: 'y'.
270 iterations (int, optional): Number of times to update x coordinates.
271 Default: 10.
273 Returns:
274 DataFrame: DataFrame with annealed anneal axis coordinates.
275 '''
276 x = anneal_axis
277 y = pin_axis
278 data[x] = data[x].astype(float)
279 data[y] = data[y].astype(float)
280 for iteration in range(iterations):
281 # create directed graph from data
282 graph = RepoETL._to_networkx_graph(data)
284 # reverse connectivity every other iteration
285 if iteration % 2 == 0:
286 graph = graph.reverse()
288 # get mean coordinate of each node in directed graph
289 for name in graph.nodes:
290 tree = networkx.bfs_tree(graph, name)
291 mu = np.mean([graph.nodes[n][x] for n in tree])
292 graph.nodes[name][x] = mu
294 # update data coordinate column
295 for node in graph.nodes:
296 mask = data[data.node_name == node].index
297 data.loc[mask, x] = graph.nodes[node][x]
299 # rectify data coordinate column, so that no two nodes overlap
300 data.sort_values(x, inplace=True)
301 for yi in data[y].unique():
302 mask = data[data[y] == yi].index
303 values = data.loc[mask, x].tolist()
304 values = list(range(len(values)))
305 data.loc[mask, x] = values
307 return data
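# ------------------------------------------------------------------------------
# Worked sketch of one annealing pass (illustrative, not part of repo_etl.py):
# each node's x is replaced by the mean x over its BFS tree (the node plus
# everything reachable from it), e.g. a tree whose members sit at x = 0, 2 and
# 4 yields x = 2.0. Once every node has been updated, nodes sharing a y level
# are re-ranked 0, 1, 2, ... by their new x values so no two of them overlap,
# and the graph's direction is reversed on alternate iterations.
# ------------------------------------------------------------------------------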
309 @staticmethod
310 def _center_coordinate(data, center_axis='x', pin_axis='y'):
311 # type: (DataFrame, str, str) -> DataFrame
312 '''
313 Center the center_axis coordinates at each level of the pin axis.
315 Args:
316 data (DataFrame): DataFrame with x column.
317 center_axis (str, optional): Coordinate column to be centered.
318 Default: 'x'.
319 pin_axis (str, optional): Coordinate column to be held constant.
320 Default: 'y'.
324 Returns:
325 DataFrame: DataFrame with centered center axis coordinates.
326 '''
327 x = center_axis
328 y = pin_axis
329 max_ = data[x].max()
330 for yi in data[y].unique():
331 mask = data[data[y] == yi].index
332 l_max = data.loc[mask, x].max()
333 delta = max_ - l_max
334 data.loc[mask, x] += (delta / 2)
335 return data
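# ------------------------------------------------------------------------------
# Worked sketch of the centering step (illustrative, not part of repo_etl.py):
# if the widest row spans x = 0..6 and a narrower row spans x = 0..2, the
# narrow row is shifted right by (6 - 2) / 2 = 2.0, so both rows end up sharing
# roughly the same horizontal center.
# ------------------------------------------------------------------------------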
337 @staticmethod
338 def _to_networkx_graph(data, escape_chars=False):
339 # type: (DataFrame, bool) -> networkx.DiGraph
340 '''
341 Converts given DataFrame into networkx directed graph.
343 Args:
344 data (DataFrame): DataFrame of nodes.
345 escape_chars (bool, optional): Escape special characters. Used to
346 avoid dot file errors. Default: False.
348 Returns:
349 networkx.DiGraph: Graph of nodes.
350 '''
351 # escape periods for dot file interpolation
352 if escape_chars:
353 data = data.copy()
354 data.node_name = data.node_name \
355 .fillna('') \
356 .apply(lambda x: re.sub(r'\.', '\\.', x))
357 data.dependencies = data.dependencies \
358 .apply(lambda x: [re.sub(r'\.', '\\.', y) for y in x])
360 graph = networkx.DiGraph()
361 data.apply(
362 lambda x: graph.add_node(
363 x.node_name,
364 **{k: getattr(x, k) for k in x.index}
365 ),
366 axis=1
367 )
369 data.apply(
370 lambda x: [graph.add_edge(p, x.node_name) for p in x.dependencies],
371 axis=1
372 )
373 return graph
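# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): the DataFrame-to-
# DiGraph conversion. Every row becomes a node carrying its columns as
# attributes; every dependency becomes an edge from dependency to dependent.
import pandas as pd
from rolling_pin.repo_etl import RepoETL

nodes = pd.DataFrame([
    dict(node_name='mylib',      node_type='subpackage', dependencies=[]),
    dict(node_name='mylib.core', node_type='module',     dependencies=['mylib']),
])
graph = RepoETL._to_networkx_graph(nodes)
list(graph.edges)                        # [('mylib', 'mylib.core')]
graph.nodes['mylib.core']['node_type']   # 'module'
# ------------------------------------------------------------------------------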
375 def to_networkx_graph(self):
376 # type: () -> networkx.DiGraph
377 '''
378 Converts internal data into networkx directed graph.
380 Returns:
381 networkx.DiGraph: Graph of nodes.
382 '''
383 return RepoETL._to_networkx_graph(self._data)
385 def to_dot_graph(self, orient='tb', orthogonal_edges=False, color_scheme=None):
386 # type: (str, bool, Optional[Dict[str, str]]) -> pydot.Dot
387 '''
388 Converts internal data into pydot graph.
390 Args:
391 orient (str, optional): Graph layout orientation. Default: tb.
392 Options include:
394 * tb - top to bottom
395 * bt - bottom to top
396 * lr - left to right
397 * rl - right to left
398 orthogonal_edges (bool, optional): Whether graph edges should have
399 right angles. Default: False.
400 color_scheme (dict, optional): Color scheme to be applied to graph.
401 Default: rolling_pin.tools.COLOR_SCHEME
403 Raises:
404 ValueError: If orient is invalid.
406 Returns:
407 pydot.Dot: Dot graph of nodes.
408 '''
409 orient = orient.lower()
410 orientations = ['tb', 'bt', 'lr', 'rl']
411 if orient not in orientations:
412 msg = f'Invalid orient value. {orient} not in {orientations}.'
413 raise ValueError(msg)
415 # set color scheme of graph
416 if color_scheme is None:
417 color_scheme = rpt.COLOR_SCHEME
419 # create dot graph
420 graph = self._to_networkx_graph(self._data, escape_chars=True)
421 dot = networkx.drawing.nx_pydot.to_pydot(graph)
423 # set layout orientation
424 dot.set_rankdir(orient.upper())
426 # set graph background color
427 dot.set_bgcolor(color_scheme['background'])
429 # set edge draw type
430 if orthogonal_edges:
431 dot.set_splines('ortho')
433 # set draw parameters for each node in graph
434 for node in dot.get_nodes():
435 # set node shape, color and font attributes
436 node.set_shape('rect')
437 node.set_style('filled')
438 node.set_color(color_scheme['node'])
439 node.set_fillcolor(color_scheme['node'])
440 node.set_fontname('Courier')
442 nx_node = re.sub('"', '', node.get_name())
443 nx_node = graph.nodes[nx_node]
445 # if networkx node has no attributes skip it
446 # this should not ever occur but might
447 if nx_node == {}:
448 continue # pragma: no cover
450 # set node x, y coordinates
451 node.set_pos(f"{nx_node['x']},{nx_node['y']}!")
453 # vary node font color by node type
454 if nx_node['node_type'] == 'library':
455 node.set_fontcolor(color_scheme['node_library_font'])
456 elif nx_node['node_type'] == 'subpackage':
457 node.set_fontcolor(color_scheme['node_subpackage_font'])
458 else:
459 node.set_fontcolor(color_scheme['node_module_font'])
461 # set draw parameters for each edge in graph
462 for edge in dot.get_edges():
463 # get networkx source node of edge
464 nx_node = dot.get_node(edge.get_source())
465 nx_node = nx_node[0].get_name()
466 nx_node = re.sub('"', '', nx_node)
467 nx_node = graph.nodes[nx_node]
469 # if networkx source node has no attributes skip it
470 # this should not ever occur but might
471 if nx_node == {}:
472 continue # pragma: no cover
474 # vary edge color by its source node type
475 if nx_node['node_type'] == 'library':
476 edge.set_color(color_scheme['edge_library'])
477 elif nx_node['node_type'] == 'subpackage':
478 edge.set_color(color_scheme['edge_subpackage'])
479 else:
480 # this line is actually covered but pytest doesn't think so
481 edge.set_color(color_scheme['edge_module']) # pragma: no cover
483 return dot
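# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): rendering the
# dependency graph as a pydot Dot object. The repository path is hypothetical.
from rolling_pin.repo_etl import RepoETL

etl = RepoETL('/path/to/my_repo/python')
dot = etl.to_dot_graph(orient='lr', orthogonal_edges=True)
# orient must be one of 'tb', 'bt', 'lr' or 'rl'; anything else raises ValueError
# ------------------------------------------------------------------------------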
485 def to_dataframe(self):
486 # type: () -> DataFrame
487 '''
488 Returns:
489 DataFrame: DataFrame of nodes representing repo modules.
490 '''
491 return self._data.copy()
493 def to_html(
494 self,
495 layout='dot',
496 orthogonal_edges=False,
497 color_scheme=None,
498 as_png=False
499 ):
500 # type: (str, bool, Optional[Dict[str, str]], bool) -> Union[HTML, Image]
501 '''
502 For use in inline rendering of graph data in Jupyter Lab.
504 Args:
505 layout (str, optional): Graph layout style.
506 Options include: circo, dot, fdp, neato, sfdp, twopi.
507 Default: dot.
508 orthogonal_edges (bool, optional): Whether graph edges should have
509 right angles. Default: False.
510 color_scheme (dict, optional): Color scheme to be applied to graph.
511 Default: rolling_pin.tools.COLOR_SCHEME
512 as_png (bool, optional): Display graph as a PNG image instead of
513 SVG. Useful for display on Github. Default: False.
515 Returns:
516 IPython.display.HTML or IPython.display.Image: Object for inline display.
517 '''
518 if color_scheme is None:
519 color_scheme = rpt.COLOR_SCHEME
521 dot = self.to_dot_graph(
522 orthogonal_edges=orthogonal_edges,
523 color_scheme=color_scheme,
524 )
525 return rpt.dot_to_html(dot, layout=layout, as_png=as_png)
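# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): inline display inside
# a Jupyter Lab cell; as_png=True swaps the SVG for a PNG, which also renders
# in GitHub READMEs. The repository path is hypothetical.
from rolling_pin.repo_etl import RepoETL

etl = RepoETL('/path/to/my_repo/python')
etl.to_html(layout='dot', as_png=False)   # last expression in a cell renders inline
# ------------------------------------------------------------------------------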
527 def write(
528 self,
529 fullpath,
530 layout='dot',
531 orient='tb',
532 orthogonal_edges=False,
533 color_scheme=None
534 ):
535 # type: (Union[str, Path], str, str, bool, Optional[Dict[str, str]]) -> RepoETL
536 '''
537 Writes internal data to a given filepath.
538 Formats supported: svg, dot, png, json.
540 Args:
541 fullpath (str or Path): File to be written to.
542 layout (str, optional): Graph layout style.
543 Options include: circo, dot, fdp, neato, sfdp, twopi. Default: dot.
544 orient (str, optional): Graph layout orientation. Default: tb.
545 Options include:
547 * tb - top to bottom
548 * bt - bottom to top
549 * lr - left to right
550 * rl - right to left
551 orthogonal_edges (bool, optional): Whether graph edges should have
552 right angles. Default: False.
553 color_scheme (dict, optional): Color scheme to be applied to graph.
554 Default: rolling_pin.tools.COLOR_SCHEME
556 Raises:
557 ValueError: If invalid file extension given.
559 Returns:
560 RepoETL: Self.
561 '''
562 if isinstance(fullpath, Path):
563 fullpath = fullpath.absolute().as_posix()
565 _, ext = os.path.splitext(fullpath)
566 ext = re.sub(r'^\.', '', ext)
567 if re.search('^json$', ext, re.I):
568 self._data.to_json(fullpath, orient='records')
569 return self
571 if color_scheme is None:
572 color_scheme = rpt.COLOR_SCHEME
574 graph = self.to_dot_graph(
575 orient=orient,
576 orthogonal_edges=orthogonal_edges,
577 color_scheme=color_scheme,
578 )
579 try:
580 rpt.write_dot_graph(graph, fullpath, layout=layout,)
581 except ValueError:
582 msg = f'Invalid extension found: {ext}. '
583 msg += 'Valid extensions include: svg, dot, png, json.'
584 raise ValueError(msg)
585 return self
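# ------------------------------------------------------------------------------
# Example (illustrative sketch, not part of repo_etl.py): writing the graph to
# disk. The file extension selects the format; unsupported extensions raise
# ValueError. Paths are hypothetical.
from rolling_pin.repo_etl import RepoETL

etl = RepoETL('/path/to/my_repo/python')
etl.write('/tmp/my_repo_graph.svg', orient='lr')   # graphviz-rendered SVG
etl.write('/tmp/my_repo_graph.json')               # node table as JSON records
# ------------------------------------------------------------------------------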