Coverage for /home/ubuntu/shekels/python/shekels/core/data_tools.py: 99% (221 statements)
coverage.py v7.1.0, created at 2023-11-15 00:54 +0000

from typing import Any, Dict, List, Optional, Union  # noqa: F401
import cufflinks as cf  # noqa: F401

from copy import copy
from random import randint
import datetime as dt
import re

from lunchbox.enforce import Enforce
from pandas import DataFrame, DatetimeIndex
from schematics.exceptions import DataError
import lunchbox.tools as lbt
import numpy as np
import pandasql
import pyparsing as pp
import rolling_pin.blob_etl as rpb
import webcolors

from shekels.core.config import ConformAction
import shekels.core.config as cfg
import shekels.enforce.enforce_tools as eft
# ------------------------------------------------------------------------------


COLOR_COERCION_LUT = {
    '#00CC96': '#5F95DE',
    '#0D0887': '#444459',
    '#19D3F3': '#5F95DE',
    '#242424': '#242424',
    '#276419': '#343434',
    '#2A3F5F': '#444459',
    '#343434': '#343434',
    '#444444': '#444444',
    '#46039F': '#444459',
    '#4D9221': '#444444',
    '#636EFA': '#5F95DE',
    '#7201A8': '#5D5D7A',
    '#7FBC41': '#8BD155',
    '#8E0152': '#444459',
    '#9C179E': '#5D5D7A',
    '#A4A4A4': '#A4A4A4',
    '#AB63FA': '#AC92DE',
    '#B6E880': '#A0D17B',
    '#B6ECF3': '#B6ECF3',
    '#B8E186': '#A0D17B',
    '#BD3786': '#F77E70',
    '#C51B7D': '#F77E70',
    '#C8D4E3': '#B6ECF3',
    '#D8576B': '#F77E70',
    '#DE77AE': '#DE958E',
    '#DE958E': '#DE958E',
    '#E5ECF6': '#F4F4F4',
    '#E6F5D0': '#E9EABE',
    '#EBF0F8': '#F4F4F4',
    '#ED7953': '#F77E70',
    '#EF553B': '#F77E70',
    '#F0F921': '#E8EA7E',
    '#F1B6DA': '#C98FDE',
    '#F4F4F4': '#F4F4F4',
    '#F7F7F7': '#F4F4F4',
    '#FB9F3A': '#EB9E58',
    '#FDCA26': '#EB9E58',
    '#FDE0EF': '#F4F4F4',
    '#FECB52': '#EB9E58',
    '#FF6692': '#F77E70',
    '#FF97FF': '#C98FDE',
    '#FFA15A': '#EB9E58',
}


def conform(data, actions=[], columns=[]):
    # type: (DataFrame, List[dict], List[str]) -> DataFrame
    '''
    Conform given mint transaction data.

    Args:
        data (DataFrame): Mint transactions DataFrame.
        actions (list[dict], optional): List of conform actions. Default: [].
        columns (list[str], optional): List of columns. Default: [].

    Raises:
        DataError: If invalid conform action given.
        ValueError: If source column not found in data columns.

    Returns:
        DataFrame: Conformed DataFrame.
    '''
    for action in actions:
        ConformAction(action).validate()

    data.rename(lbt.to_snakecase, axis=1, inplace=True)
    lut = dict(
        account_name='account',
        transaction_type='type'
    )
    data.rename(lambda x: lut.get(x, x), axis=1, inplace=True)
    data.date = DatetimeIndex(data.date)
    data.amount = data.amount.astype(float)
    data.category = data.category \
        .apply(lambda x: re.sub('&', 'and', lbt.to_snakecase(x)))
    data.account = data.account.apply(lbt.to_snakecase)

    for action in actions:
        source = action['source_column']
        if source not in data.columns:
            msg = f'Source column {source} not found in columns. '
            msg += f'Legal columns include: {data.columns.tolist()}.'
            raise ValueError(msg)

        target = action['target_column']
        if target not in data.columns:
            data[target] = None

        for regex, val in action['mapping'].items():
            if action['action'] == 'overwrite':
                mask = data[source] \
                    .apply(lambda x: re.search(regex, x, flags=re.I)).astype(bool)
                data.loc[mask, target] = val
            elif action['action'] == 'substitute':
                data[target] = data[source] \
                    .apply(lambda x: re.sub(regex, val, str(x), flags=re.I))

    if columns != []:
        data = data[columns]
    return data
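

# A minimal usage sketch of conform (not part of the original module; the data
# and mapping below are hypothetical). It renames Mint-style columns, coerces
# dtypes, then overwrites the category of rows whose description matches a
# regex.
def _example_conform():
    import pandas as pd
    data = pd.DataFrame({
        'Date': ['2023-01-01', '2023-01-02'],
        'Amount': ['9.99', '120.00'],
        'Category': ['Food & Dining', 'Bills'],
        'Account Name': ['Checking', 'Checking'],
        'Description': ['coffee shop', 'electric co'],
    })
    action = dict(
        action='overwrite',
        source_column='description',
        target_column='category',
        mapping={'coffee': 'coffee_shops'},
    )
    return conform(data, actions=[action])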


def filter_data(data, column, comparator, value):
    # type: (DataFrame, str, str, Any) -> DataFrame
    '''
    Filters given data via comparator(column value, value).

    Legal comparators:

        * == ``lambda a, b: a == b``
        * != ``lambda a, b: a != b``
        * > ``lambda a, b: a > b``
        * >= ``lambda a, b: a >= b``
        * < ``lambda a, b: a < b``
        * <= ``lambda a, b: a <= b``
        * ~ ``lambda a, b: bool(re.search(b, a, flags=re.I))``
        * !~ ``lambda a, b: not bool(re.search(b, a, flags=re.I))``

    Args:
        data (DataFrame): DataFrame to be filtered.
        column (str): Column name.
        comparator (str): String representation of comparator.
        value (object): Value to be compared.

    Raises:
        EnforceError: If data is not a DataFrame.
        EnforceError: If column is not a string.
        EnforceError: If column not in data columns.
        EnforceError: If illegal comparator given.
        EnforceError: If comparator is ~ or !~ and value is not a string.

    Returns:
        DataFrame: Filtered data.
    '''
    Enforce(data, 'instance of', DataFrame)
    msg = 'Column must be a str. {a} is not str.'
    Enforce(column, 'instance of', str, message=msg)
    eft.enforce_columns_in_dataframe([column], data)

    lut = {
        '==': lambda a, b: a == b,
        '!=': lambda a, b: a != b,
        '>': lambda a, b: a > b,
        '>=': lambda a, b: a >= b,
        '<': lambda a, b: a < b,
        '<=': lambda a, b: a <= b,
        '~': lambda a, b: bool(re.search(b, a, flags=re.I)),
        '!~': lambda a, b: not bool(re.search(b, a, flags=re.I)),
    }
    msg = 'Illegal comparator. {a} not in [==, !=, >, >=, <, <=, ~, !~].'
    Enforce(comparator, 'in', lut.keys(), message=msg)

    if comparator in ['~', '!~']:
        msg = 'Value must be string if comparator is ~ or !~. {a} is not str.'
        Enforce(value, 'instance of', str, message=msg)
    # --------------------------------------------------------------------------

    op = lut[comparator]
    mask = data[column].apply(lambda x: op(x, value))
    data = data[mask]
    return data
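

# A minimal usage sketch of filter_data (hypothetical data, not part of the
# original module): keeps rows whose description matches the regex 'coffee',
# case insensitively.
def _example_filter_data():
    import pandas as pd
    data = pd.DataFrame({
        'description': ['Coffee Shop', 'Electric Co'],
        'amount': [9.99, 120.0],
    })
    return filter_data(data, 'description', '~', 'coffee')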


def group_data(data, columns, metric, datetime_column='date'):
    # type: (DataFrame, Union[str, List[str]], str, str) -> DataFrame
    '''
    Groups given data by given columns according to given metric.
    If a legal time interval is given in the columns, then an additional special
    column of that same name is added to the data for grouping.

    Legal metrics:

        * max ``lambda x: x.max()``
        * mean ``lambda x: x.mean()``
        * min ``lambda x: x.min()``
        * std ``lambda x: x.std()``
        * sum ``lambda x: x.sum()``
        * var ``lambda x: x.var()``
        * count ``lambda x: x.count()``

    Legal time intervals:

        * year
        * quarter
        * month
        * two_week
        * week
        * day
        * hour
        * half_hour
        * quarter_hour
        * minute
        * second
        * microsecond

    Args:
        data (DataFrame): DataFrame to be grouped.
        columns (str or list[str]): Columns to group data by.
        metric (str): String representation of metric.
        datetime_column (str, optional): Datetime column for time grouping.
            Default: date.

    Raises:
        EnforceError: If data is not a DataFrame.
        EnforceError: If columns not in data columns.
        EnforceError: If illegal metric given.
        EnforceError: If time interval in columns and datetime_column not in
            columns.

    Returns:
        DataFrame: Grouped data.
    '''
    # luts
    met_lut = {
        'max': lambda x: x.max(),
        'mean': lambda x: x.mean(),
        'min': lambda x: x.min(),
        'std': lambda x: x.std(),
        'sum': lambda x: x.sum(),
        'var': lambda x: x.var(),
        'count': lambda x: x.count(),
    }

    time_lut = {
        'year': lambda x: dt.datetime(x.year, 1, 1),
        'quarter': lambda x: dt.datetime(
            x.year, int(np.ceil(x.month / 3) * 3 - 2), 1
        ),
        'month': lambda x: dt.datetime(x.year, x.month, 1),
        'two_week': lambda x: dt.datetime(
            x.year, x.month, min(int(np.ceil(x.day / 14) * 14 - 13), 28)
        ),
        'week': lambda x: dt.datetime(
            x.year, x.month, max(1, min([int(x.day / 7) * 7, 28]))
        ),
        'day': lambda x: dt.datetime(x.year, x.month, x.day),
        'hour': lambda x: dt.datetime(x.year, x.month, x.day, x.hour),
        'half_hour': lambda x: dt.datetime(
            x.year, x.month, x.day, x.hour, int(x.minute / 30) * 30
        ),
        'quarter_hour': lambda x: dt.datetime(
            x.year, x.month, x.day, x.hour, int(x.minute / 15) * 15
        ),
        'minute': lambda x: dt.datetime(
            x.year, x.month, x.day, x.hour, x.minute
        ),
        'second': lambda x: dt.datetime(
            x.year, x.month, x.day, x.hour, x.minute, x.second
        ),
        'microsecond': lambda x: dt.datetime(
            x.year, x.month, x.day, x.hour, x.minute, x.second, x.microsecond
        ),
    }
    # --------------------------------------------------------------------------

    # enforcements
    Enforce(data, 'instance of', DataFrame)
    columns_ = columns  # type: Any
    if not isinstance(columns_, list):
        columns_ = [columns_]

    cols = list(filter(lambda x: x not in time_lut.keys(), columns_))
    eft.enforce_columns_in_dataframe(cols, data)

    msg = '{a} is not a legal metric. Legal metrics: {b}.'
    Enforce(metric, 'in', sorted(list(met_lut.keys())), message=msg)

    # time column
    if len(columns_) > len(cols):
        eft.enforce_columns_in_dataframe([datetime_column], data)
        msg = 'Datetime column is of type {a}. It must be of type {b}.'
        Enforce(
            data[datetime_column].dtype.type, '==', np.datetime64, message=msg
        )
    # --------------------------------------------------------------------------

    for col in columns_:
        if col in time_lut.keys():
            op = time_lut[col]
            data[col] = data[datetime_column].apply(op)
    agg = met_lut[metric]
    cols = data.columns.tolist()
    grp = data.groupby(columns_, as_index=False)
    output = agg(grp)

    # get first value for columns that cannot be computed by given metric
    diff = set(cols).difference(output.columns.tolist())
    if len(diff) > 0:
        first = grp.first()
        for col in diff:
            output[col] = first[col]
    return output
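

# A minimal usage sketch of group_data (hypothetical data, not part of the
# original module): sums amount per month via the special 'month' time
# interval derived from the date column. Columns the metric cannot compute
# (like date itself) are back-filled from each group's first row.
def _example_group_data():
    import pandas as pd
    data = pd.DataFrame({
        'date': pd.to_datetime(['2023-01-05', '2023-01-20', '2023-02-03']),
        'amount': [10.0, 20.0, 30.0],
    })
    return group_data(data, 'month', 'sum', datetime_column='date')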


def pivot_data(data, columns, values=[], index=None):
    # type: (DataFrame, List[str], List[str], Optional[str]) -> DataFrame
    '''
    Pivots a given dataframe via a list of columns.

    Legal time columns:

        * date
        * year
        * quarter
        * month
        * two_week
        * week
        * day
        * hour
        * half_hour
        * quarter_hour
        * minute
        * second
        * microsecond

    Args:
        data (DataFrame): DataFrame to be pivoted.
        columns (list[str]): Columns whose unique values become separate traces
            within a plot.
        values (list[str], optional): Columns whose values become the values
            within each trace of a plot. Default: [].
        index (str, optional): Column whose values become the y axis values of a
            plot. Default: None.

    Raises:
        EnforceError: If data is not a DataFrame.
        EnforceError: If data is of zero length.
        EnforceError: If columns not in data columns.
        EnforceError: If values not in data columns.
        EnforceError: If index not in data columns or legal time columns.

    Returns:
        DataFrame: Pivoted data.
    '''
    time_cols = [
        'date', 'year', 'quarter', 'month', 'two_week', 'week', 'day', 'hour',
        'half_hour', 'quarter_hour', 'minute', 'second', 'microsecond',
    ]

    Enforce(data, 'instance of', DataFrame)
    msg = 'DataFrame must be at least 1 in length. Given length: {a}.'
    Enforce(len(data), '>=', 1, message=msg)
    eft.enforce_columns_in_dataframe(columns, data)
    eft.enforce_columns_in_dataframe(values, data)
    if index is not None:
        msg = '{a} is not in legal column names: {b}.'
        Enforce(index, 'in', data.columns.tolist() + time_cols, message=msg)
    # --------------------------------------------------------------------------

    vals = copy(values)
    if index is not None and index not in values:
        vals.append(index)

    if index in time_cols:
        # add microsecond jitter so duplicate datetime index values stay unique
        data[index] = data[index] \
            .apply(lambda x: x + dt.timedelta(microseconds=randint(0, 999999)))

    data = data.pivot(columns=columns, values=vals, index=index)
    data = data[values]
    data.columns = data.columns.droplevel(0)
    return data
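

# A minimal usage sketch of pivot_data (hypothetical data, not part of the
# original module): each category becomes its own column of amounts, indexed
# by jittered date.
def _example_pivot_data():
    import pandas as pd
    data = pd.DataFrame({
        'date': pd.to_datetime(['2023-01-01', '2023-01-02']),
        'category': ['food', 'bills'],
        'amount': [9.99, 120.0],
    })
    return pivot_data(data, ['category'], values=['amount'], index='date')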


def get_figure(
    data,  # type: DataFrame
    filters=[],  # type: List[dict]
    group=None,  # type: Optional[dict]
    pivot=None,  # type: Optional[dict]
    kind='bar',  # type: str
    color_scheme={},  # type: Dict[str, str]
    x_axis=None,  # type: Optional[str]
    y_axis=None,  # type: Optional[str]
    title=None,  # type: Optional[str]
    x_title=None,  # type: Optional[str]
    y_title=None,  # type: Optional[str]
    bins=50,  # type: int
    bar_mode='stack',  # type: str
):
    '''
    Generates a plotly figure dictionary from given data and manipulations.

    Args:
        data (DataFrame): Data.
        filters (list[dict], optional): List of filters for data. Default: [].
        group (dict, optional): Grouping operation. Default: None.
        pivot (dict, optional): Pivot operation. Default: None.
        kind (str, optional): Kind of plot. Default: bar.
        color_scheme (dict[str, str], optional): Color scheme. Default: {}.
        x_axis (str, optional): Column to use as x axis. Default: None.
        y_axis (str, optional): Column to use as y axis. Default: None.
        title (str, optional): Title of plot. Default: None.
        x_title (str, optional): Title of x axis. Default: None.
        y_title (str, optional): Title of y axis. Default: None.
        bins (int, optional): Number of bins if histogram. Default: 50.
        bar_mode (str, optional): How bars in bar graph are presented.
            Default: stack.

    Raises:
        DataError: If any filter in filters is invalid.
        DataError: If group is invalid.
        DataError: If pivot is invalid.

    Returns:
        dict: Plotly Figure as dictionary.
    '''
    data = data.copy()

    # filter
    for f in filters:
        f = cfg.FilterAction(f)
        try:
            f.validate()
        except DataError as e:
            raise DataError({'Invalid filter': e.to_primitive()})

        f = f.to_primitive()
        if len(data) == 0:
            break
        data = filter_data(data, f['column'], f['comparator'], f['value'])

    # group
    if group is not None:
        grp = group  # type: Any
        grp = cfg.GroupAction(grp)
        try:
            grp.validate()
        except DataError as e:
            raise DataError({'Invalid group': e.to_primitive()})
        grp = grp.to_primitive()

        data = group_data(
            data,
            grp['columns'],
            grp['metric'],
            datetime_column=grp['datetime_column'],
        )

    # pivot
    if pivot is not None:
        pvt = pivot  # type: Any
        pvt = cfg.PivotAction(pvt)
        try:
            pvt.validate()
        except DataError as e:
            raise DataError({'Invalid pivot': e.to_primitive()})
        pvt = pvt.to_primitive()

        data = pivot_data(
            data, pvt['columns'], values=pvt['values'], index=pvt['index']
        )

    # create figure
    figure = data.iplot(
        kind=kind, asFigure=True, theme='henanigans', colorscale='henanigans',
        x=x_axis, y=y_axis, title=title, xTitle=x_title, yTitle=y_title,
        barmode=bar_mode, bins=bins
    ).to_dict()  # type: dict
    figure['layout']['title']['font']['color'] = '#F4F4F4'
    figure['layout']['xaxis']['title']['font']['color'] = '#F4F4F4'
    figure['layout']['yaxis']['title']['font']['color'] = '#F4F4F4'
    if color_scheme != {}:
        figure = conform_figure(figure, color_scheme)

    # make area traces stackable
    if kind == 'area':
        for trace in figure['data']:
            trace['stackgroup'] = 1

    return figure
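

# A minimal usage sketch of get_figure (hypothetical data, not part of the
# original module): filters, groups per month, and renders a stacked bar
# figure dictionary. Requires cufflinks for DataFrame.iplot.
def _example_get_figure():
    import pandas as pd
    data = pd.DataFrame({
        'date': pd.to_datetime(['2023-01-05', '2023-02-03', '2023-02-20']),
        'amount': [10.0, 20.0, 30.0],
    })
    return get_figure(
        data,
        filters=[dict(column='amount', comparator='>', value=5.0)],
        group=dict(columns=['month'], metric='sum', datetime_column='date'),
        kind='bar',
        title='amount per month',
    )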


def parse_rgba(string):
    '''
    Parses rgb and rgba strings into tuples of numbers.

    Example:
        >>> parse_rgba('rgb(255, 0, 0)')
        (255, 0, 0)
        >>> parse_rgba('rgba(255, 0, 0, 0.5)')
        (255, 0, 0, 0.5)
        >>> parse_rgba('foo') is None
        True

    Args:
        string (str): String to be parsed.

    Returns:
        tuple: (red, green, blue) or (red, green, blue, alpha)
    '''
    result = re.search(r'rgba?\((\d+, \d+, \d+(, \d+\.?\d*)?)\)', string)
    if result is None:
        return None

    result = result.group(1)
    result = re.split(', ', result)
    if len(result) == 3:
        result = tuple(map(int, result))
        return result

    result = list(map(int, result[:-1])) + [float(result[-1])]
    result = tuple(result)
    return result


def conform_figure(figure, color_scheme):
    '''
    Conforms given figure to use given color scheme.

    Args:
        figure (dict): Plotly figure.
        color_scheme (dict): Color scheme dictionary.

    Returns:
        dict: Conformed figure.
    '''
    # create hex to hex lut
    lut = {}
    for key, val in cfg.COLOR_SCHEME.items():
        if key in color_scheme:
            lut[val] = color_scheme[key]

    # rgba? to hex --> coerce to standard colors --> coerce with color_scheme
    figure = rpb.BlobETL(figure) \
        .set(
            predicate=lambda k, v: isinstance(v, str) and 'rgb' in v,
            value_setter=lambda k, v: webcolors.rgb_to_hex(parse_rgba(v)[:3]).upper()) \
        .set(
            predicate=lambda k, v: isinstance(v, str),
            value_setter=lambda k, v: COLOR_COERCION_LUT.get(v, v)) \
        .set(
            predicate=lambda k, v: isinstance(v, str),
            value_setter=lambda k, v: lut.get(v, v)) \
        .to_dict()
    return figure
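

# A minimal usage sketch of conform_figure (hypothetical figure; the 'bg' key
# is an assumption, real keys live in shekels.core.config.COLOR_SCHEME):
# 'rgb(229, 236, 246)' becomes '#E5ECF6', which COLOR_COERCION_LUT maps to
# '#F4F4F4'.
def _example_conform_figure():
    figure = {'layout': {'plot_bgcolor': 'rgb(229, 236, 246)'}}
    return conform_figure(figure, {'bg': '#242424'})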


# SQL-PARSING-------------------------------------------------------------------
def get_sql_grammar():
    '''
    Creates a grammar for parsing SQL queries.

    Returns:
        MatchFirst: SQL parser.
    '''
    select = pp.Regex('select', flags=re.I) \
        .setParseAction(lambda s, _, t: 'select') \
        .setResultsName('operator')
    from_ = pp.Suppress(pp.Regex('from', flags=re.I))
    table = (from_ + pp.Regex('[a-z]+', flags=re.I)) \
        .setParseAction(lambda s, _, t: t[0]) \
        .setResultsName('table')
    regex = pp.Regex('~|regex').setParseAction(lambda s, _, t: '~')
    not_regex = pp.Regex('!~|not regex').setParseAction(lambda s, _, t: '!~')
    any_op = pp.Regex('[^ ]*')
    operator = pp.Or([not_regex, regex, any_op]).setResultsName('operator')
    quote = pp.Suppress(pp.Optional("'"))
    value = (quote + pp.Regex('[^\']+', flags=re.I) + quote) \
        .setResultsName('value') \
        .setParseAction(lambda s, _, t: t[0])
    columns = pp.delimitedList(pp.Regex('[^, ]*'), delim=pp.Regex(', *')) \
        .setResultsName('display_columns')
    column = pp.Regex('[a-z]+', flags=re.I).setResultsName('column')
    conditional = column + operator + value
    head = select + columns + table
    grammar = head | conditional
    return grammar
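

# A minimal usage sketch of get_sql_grammar (not part of the original module):
# parses a single regex conditional into its column, operator and value parts.
def _example_get_sql_grammar():
    grammar = get_sql_grammar()
    parse = grammar.parseString("description ~ 'coffee'").asDict()
    # parse == {'column': 'description', 'operator': '~', 'value': 'coffee'}
    return parse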


def query_data(data, query, uri='sqlite:///:memory:'):
    '''
    Parses SQL + regex query and applies it to given data.

    Regex operators:

        * ~, regex - Match regular expression
        * !~, not regex - Do not match regular expression

    Args:
        data (DataFrame): DataFrame to be queried.
        query (str): SQL query that may include regex operators.
        uri (str, optional): Database URI. Default: sqlite:///:memory:.

    Returns:
        DataFrame: Data filtered by query.
    '''
    # split queries by where/and/or
    queries = re.split(' where | and | or ', query, flags=re.I)

    # detect whether any sub query has a regex operator
    has_regex = False
    for q in queries:
        if re.search(' regex | ~ | !~ ', q, flags=re.I):
            has_regex = True
            break

    # if no regex operator is found just submit query to pandasql
    if not has_regex:
        data = pandasql.PandaSQL(uri)(query, locals())

    else:
        grammar = get_sql_grammar()

        # move select statement to end
        if 'select' in queries[0]:
            q = queries.pop(0)
            queries.append(q)

        for q in queries:
            # get column, operator and value
            parse = grammar.parseString(q).asDict()
            op = parse['operator']

            # initial select statement
            if op == 'select':
                data = pandasql.PandaSQL(uri)(q, locals())

            # regex search
            elif op == '~':
                mask = data[parse['column']] \
                    .astype(str) \
                    .apply(lambda x: re.search(parse['value'], x, flags=re.I)) \
                    .astype(bool)
                data = data[mask]

            # regex not search
            elif op == '!~':
                mask = data[parse['column']] \
                    .astype(str) \
                    .apply(lambda x: re.search(parse['value'], x, flags=re.I)) \
                    .astype(bool)
                data = data[~mask]

            # other SQL query
            else:
                data = pandasql.sqldf('select * from data where ' + q, {'data': data})

            if len(data) == 0:
                break
    return data
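

# A minimal usage sketch of query_data (hypothetical data, not part of the
# original module): a SQL select combined with a regex where-clause.
def _example_query_data():
    import pandas as pd
    data = pd.DataFrame({
        'description': ['Coffee Shop', 'Electric Co'],
        'amount': [9.99, 120.0],
    })
    return query_data(data, "select * from data where description ~ 'coffee'")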


def query_dict(data, query):
    # type: (dict, str) -> dict
    '''
    Query a given dictionary with a given SQL query.

    Args:
        data (dict): Dictionary to be queried.
        query (str): SQL query.

    Returns:
        dict: Queried dictionary.
    '''
    data_ = data  # type: Any
    data_ = rpb.BlobETL(data_) \
        .to_flat_dict() \
        .items()
    data_ = DataFrame(list(data_), columns=['key', 'value'])
    data_ = query_data(data_, query)
    data_ = dict(zip(data_.key.tolist(), data_.value.tolist()))
    data_ = rpb.BlobETL(data_).to_dict()
    return data_
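

# A minimal usage sketch of query_dict (hypothetical data, not part of the
# original module): the nested dict is flattened to key/value rows, filtered
# by regex on the flattened keys, then re-nested.
def _example_query_dict():
    data = {'food': {'coffee': 9.99}, 'bills': {'electric': 120.0}}
    return query_dict(data, "select * from data where key ~ 'coffee'")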