Coverage for /home/ubuntu/shekels/python/shekels/enforce/enforce_tools.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-11-15 00:54 +0000

1from typing import Any, List # noqa: F401 

2 

3from lunchbox.enforce import Enforce, EnforceError 

4from pandas import DataFrame 

5# ------------------------------------------------------------------------------ 

6 

7 

def enforce_dataframes_are_equal(a, b):
    '''
    Ensures that DataFrames a and b hold the same contents.

    Checks run in order and stop at the first failure: column names, shape,
    cell values (NaNs are treated as equal to each other) and finally the
    row records.

    Args:
        a (DataFrame): DataFrame A.
        b (DataFrame): DataFrame B.

    Raises:
        EnforceError: If a and b are not equal.
    '''
    # column names: any label present in only one of the two frames
    mismatched = sorted(
        set(a.columns.tolist()).symmetric_difference(b.columns.tolist())
    )
    column_msg = f'A and b have different columns: {mismatched}.'
    Enforce(len(mismatched), '==', 0, message=column_msg)

    # shape
    # NOTE(review): deliberately not an f-string. {a} and {b} are presumably
    # template fields that Enforce fills in with the compared values --
    # confirm against lunchbox.
    shape_msg = 'A and b have different shapes. {a} != {b}.'
    Enforce(a.shape, '==', b.shape, message=shape_msg)

    # NaN != NaN, so swap NaNs for a sentinel that compares equal to itself
    sentinel = '---NAN---'
    left = a.fillna(sentinel)
    right = b.fillna(sentinel)

    # values: collect one [column, a value, b value] row per differing cell
    rows = []
    for name in left.columns:
        differs = left[name] != right[name]
        left_vals = left.loc[differs, name].tolist()
        if not left_vals:
            continue
        right_vals = right.loc[differs, name].tolist()
        rows.extend([name, lv, rv] for lv, rv in zip(left_vals, right_vals))

    if rows:
        table = DataFrame(rows, columns=['column', 'a', 'b']).to_string()
        raise EnforceError(f'DatFrames have different values:\n{table}')

    # records
    Enforce(
        left.to_dict(orient='records'),
        '==',
        right.to_dict(orient='records'),
    )

55 

56 

def enforce_columns_in_dataframe(columns, data):
    # type: (List[str], DataFrame) -> None
    '''
    Ensures that every given column name exists in the given DataFrame.

    Args:
        columns (list[str]): Column names.
        data (DataFrame): DataFrame.

    Raises:
        EnforceError: If any column not found in data.columns.
    '''
    available = data.columns.tolist()

    # requested names that are absent from the DataFrame, in sorted order
    missing = sorted(set(columns).difference(available))  # type: List[str]

    Enforce(
        missing,
        '==',
        [],
        message=f'Given columns not found in data. {missing} not in {available}.',
    )