Coverage for /home/ubuntu/shekels/python/shekels/enforce/enforce_tools.py: 100%
35 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-11-15 00:54 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-11-15 00:54 +0000
1from typing import Any, List # noqa: F401
3from lunchbox.enforce import Enforce, EnforceError
4from pandas import DataFrame
5# ------------------------------------------------------------------------------
def enforce_dataframes_are_equal(a, b):
    # type: (DataFrame, DataFrame) -> None
    '''
    Ensures that DataFrames a and b have equal contents.

    Column names, shape and cell values are compared. Columns are matched by
    name (column order is ignored) and rows are matched by position (index
    labels are ignored). NaN cells in a and b are treated as equal to each
    other.

    Args:
        a (DataFrame): DataFrame A.
        b (DataFrame): DataFrame B.

    Raises:
        EnforceError: If a and b are not equal.
    '''
    # column names
    diff = set(a.columns.tolist()).symmetric_difference(b.columns.tolist())
    diff = sorted(diff)
    msg = f'A and b have different columns: {diff}.'
    Enforce(len(diff), '==', 0, message=msg)

    # shape
    # NOTE(review): this is deliberately not an f-string; the {a} and {b}
    # placeholders are presumably filled in by Enforce -- confirm.
    msg = 'A and b have different shapes. {a} != {b}.'
    Enforce(a.shape, '==', b.shape, message=msg)

    # NaNs cannot be compared (NaN != NaN), so replace them with a sentinel.
    # The index is reset so that the column comparisons below are positional:
    # pandas raises a ValueError when comparing Series whose index labels
    # differ, and the final records comparison ignores the index anyway.
    a = a.fillna('---NAN---').reset_index(drop=True)
    b = b.fillna('---NAN---').reset_index(drop=True)

    # values
    errors = []
    for col in a.columns:
        mask = a[col] != b[col]
        a_vals = a.loc[mask, col].tolist()
        if a_vals:
            b_vals = b.loc[mask, col].tolist()
            errors.extend([col, av, bv] for av, bv in zip(a_vals, b_vals))

    if errors:
        # one row per mismatched cell: column name, value in a, value in b
        msg = DataFrame(errors, columns=['column', 'a', 'b']).to_string()
        # NOTE: the spelling of this message is kept as is, because callers
        # may match on it.
        msg = f'DatFrames have different values:\n{msg}'
        raise EnforceError(msg)

    # records
    a = a.to_dict(orient='records')
    b = b.to_dict(orient='records')
    Enforce(a, '==', b)
def enforce_columns_in_dataframe(columns, data):
    # type: (List[str], DataFrame) -> None
    '''
    Ensures that every given column name is present in the given DataFrame.

    Args:
        columns (list[str]): Column names to look for.
        data (DataFrame): DataFrame whose columns are searched.

    Raises:
        EnforceError: If any column not found in data.columns.
    '''
    cols = data.columns.tolist()

    # requested names that are absent from data, sorted for a stable message
    missing = sorted(set(columns) - set(cols))  # type: Any

    msg = f'Given columns not found in data. {missing} not in {cols}.'
    Enforce(missing, '==', [], message=msg)