diff --git a/pointblank/_agg.py b/pointblank/_agg.py new file mode 100644 index 000000000..5d4cf1655 --- /dev/null +++ b/pointblank/_agg.py @@ -0,0 +1,112 @@ +from collections.abc import Callable +from typing import Any + +import narwhals as nw + +# TODO: Should take any frame type +Aggregator = Callable[[nw.DataFrame], Any] +Comparator = Callable[[Any, Any], bool] + +AGGREGATOR_REGISTRY: dict[str, Aggregator] = {} + +COMPARATOR_REGISTRY: dict[str, Comparator] = {} + + +def register(fn): + name: str = fn.__name__ + if name.startswith("comp_"): + COMPARATOR_REGISTRY[name.removeprefix("comp_")] = fn + elif name.startswith("agg_"): + AGGREGATOR_REGISTRY[name.removeprefix("agg_")] = fn + else: + raise NotImplementedError # pragma: no cover + return fn + + +## Aggregator Functions +@register +def agg_sum(column: nw.DataFrame) -> float: + return column.select(nw.all().sum()).item() + + +@register +def agg_avg(column: nw.DataFrame) -> float: + return column.select(nw.all().mean()).item() + + +@register +def agg_sd(column: nw.DataFrame) -> float: + return column.select(nw.all().std()).item() + + +## Comparator functions +@register +def comp_eq(real: float, lower: float, upper: float) -> bool: + if lower == upper: + return bool(real == lower) + return _generic_between(real, lower, upper) + + +@register +def comp_gt(real: float, lower: float, upper: float) -> bool: + if lower == upper: + return bool(real > lower) + return bool(real > lower) + + +@register +def comp_ge(real: Any, lower: float, upper: float) -> bool: + if lower == upper: + return bool(real >= lower) + return bool(real >= lower) + + +@register +def comp_lt(real: float, lower: float, upper: float) -> bool: + if lower == upper: + return bool(real < lower) + return bool(real < upper) + + +@register +def comp_le(real: float, lower: float, upper: float) -> bool: + if lower == upper: + return bool(real <= lower) + return bool(real <= upper) + + +def _generic_between(real: Any, lower: Any, upper: Any) -> bool: + """Call if comparator needs to check between two values.""" + return bool(lower <= real <= upper) + + +def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]: + """Resolve the assertion name to a valid aggregator + + Args: + name (str): The name of the assertion. + + Returns: + tuple[Aggregator, Comparator]: _description_ + """ + name = name.removeprefix("col_") + agg_name, comp_name = name.split("_")[-2:] + + aggregator = AGGREGATOR_REGISTRY.get(agg_name) + comparator = COMPARATOR_REGISTRY.get(comp_name) + + if aggregator is None: + raise ValueError(f"Aggregator '{agg_name}' not found in registry.") + + if comparator is None: + raise ValueError(f"Comparator '{comp_name}' not found in registry.") + + return aggregator, comparator + + +def is_valid_agg(name: str) -> bool: + try: + resolve_agg_registries(name) + return True + except ValueError: + return False diff --git a/pointblank/_constants.py b/pointblank/_constants.py index 39d658460..4678f7f50 100644 --- a/pointblank/_constants.py +++ b/pointblank/_constants.py @@ -51,6 +51,7 @@ "tbl_match": "tbl_match", "conjointly": "conjointly", "specially": "specially", + "col_sum_eq": "sum_eq", } COMPARISON_OPERATORS = { @@ -220,6 +221,48 @@ CROSS_MARK_SPAN = "" SVG_ICONS_FOR_ASSERTION_TYPES = { + ## EQ Icons + "col_vals_eq": """ + + col_vals_eq + + + + + + +""", + "col_sum_eq": """ + + col_vals_gt + + + + + + +""", + "col_sd_eq": """ + + col_vals_gt + + + + + + +""", + "col_avg_eq": """ + + col_vals_gt + + + + + + +""", + ## GT Icons "col_vals_gt": """ col_vals_gt @@ -230,6 +273,37 @@ """, + "col_sum_gt": """ + + col_vals_gt + + + + + + +""", + "col_sd_gt": """ + + col_vals_gt + + + + + + +""", + "col_avg_gt": """ + + col_vals_gt + + + + + + +""", + ## LT Icons "col_vals_lt": """ col_vals_lt @@ -240,16 +314,37 @@ """, - "col_vals_eq": """ + "col_sum_lt": """ - col_vals_eq + col_vals_lt - + - + + + +""", + "col_sd_lt": """ + + col_vals_lt + + + + + + +""", + "col_avg_lt": """ + + col_vals_lt + + + + """, + ## NE Icons "col_vals_ne": """ col_vals_ne @@ -260,6 +355,37 @@ """, + "col_sum_ne": """ + + col_vals_ne + + + + + + +""", + "col_sd_ne": """ + + col_vals_ne + + + + + + +""", + "col_avg_ne": """ + + col_vals_ne + + + + + + +""", + ## GE Icons "col_vals_ge": """ col_vals_ge @@ -270,6 +396,37 @@ """, + "col_sum_ge": """ + + col_vals_ge + + + + + + +""", + "col_sd_ge": """ + + col_vals_ge + + + + + + +""", + "col_avg_ge": """ + + col_vals_ge + + + + + + +""", + ## LE Icons "col_vals_le": """ col_vals_le @@ -279,6 +436,36 @@ +""", + "col_sum_le": """ + + col_vals_le + + + + + + +""", + "col_sd_le": """ + + col_vals_le + + + + + + +""", + "col_avg_le": """ + + col_vals_le + + + + + + """, "col_vals_between": """ diff --git a/pointblank/_utils.py b/pointblank/_utils.py index 1af2c4969..fa3a20a8d 100644 --- a/pointblank/_utils.py +++ b/pointblank/_utils.py @@ -11,6 +11,7 @@ from narwhals.typing import FrameT from pointblank._constants import ASSERTION_TYPE_METHOD_MAP, GENERAL_COLUMN_TYPES, IBIS_BACKENDS +from pointblank.column import Column, ColumnLiteral, ColumnSelector, ColumnSelectorNarwhals, col if TYPE_CHECKING: from collections.abc import Mapping @@ -550,6 +551,23 @@ def _column_subset_test_prep( return dfn +_PBUnresolvedColumn = str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals +_PBResolvedColumn = Column | ColumnLiteral | ColumnSelectorNarwhals | list[Column] | list[str] + + +def _resolve_columns(columns: _PBUnresolvedColumn) -> _PBResolvedColumn: + # If `columns` is a ColumnSelector or Narwhals selector, call `col()` on it to later + # resolve the columns + if isinstance(columns, (ColumnSelector, nw.selectors.Selector)): + columns = col(columns) + + # If `columns` is Column value or a string, place it in a list for iteration + if isinstance(columns, (Column, str)): + columns = [columns] + + return columns + + def _get_fn_name() -> str: # Get the current function name fn_name = inspect.currentframe().f_back.f_code.co_name @@ -660,10 +678,10 @@ def _format_to_float_value( def _pivot_to_dict(col_dict: Mapping[str, Any]): # TODO : Type hint and unit test result_dict = {} - for col, sub_dict in col_dict.items(): + for _col, sub_dict in col_dict.items(): for key, value in sub_dict.items(): # add columns fields not present if key not in result_dict: result_dict[key] = [None] * len(col_dict) - result_dict[key][list(col_dict.keys()).index(col)] = value + result_dict[key][list(col_dict.keys()).index(_col)] = value return result_dict diff --git a/pointblank/validate.py b/pointblank/validate.py index 3ad0fe388..874ee02ad 100644 --- a/pointblank/validate.py +++ b/pointblank/validate.py @@ -25,6 +25,7 @@ from importlib_resources import files from narwhals.typing import FrameT +from pointblank._agg import is_valid_agg, resolve_agg_registries from pointblank._constants import ( ASSERTION_TYPE_METHOD_MAP, CHECK_MARK_SPAN, @@ -90,6 +91,8 @@ _is_lib_present, _is_narwhals_table, _is_value_a_df, + _PBUnresolvedColumn, + _resolve_columns, _select_df_lib, ) from pointblank._utils_check_args import ( @@ -133,6 +136,50 @@ "get_validation_summary", ] +from functools import wraps +from typing import Callable, ParamSpec, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") + + +def _register_agg_validations(func: Callable[P, R]) -> Callable[P, R]: + """ + Decorator that handles the standard validation pattern for aggregate validators. + + The decorated function should just be a stub that defines the signature. + """ + + @wraps(func) + def wrapper( + self: Validate, + columns, + value, + tol=0, + thresholds=None, + brief=False, + actions=None, + active=True, + ): + for column in columns: + val_info = _ValidationInfo.from_agg_validator( + assertion_type=func.__name__, # Use the function name + columns=column, + value=value, + tol=tol, + thresholds=self.thresholds if thresholds is None else thresholds, + actions=self.actions if actions is None else actions, + brief=self.brief if brief is None else brief, + active=active, + ) + + self._add_validation(validation_info=val_info) + + return self + + return wrapper + + # Create a thread-local storage for the metadata _action_context = threading.local() @@ -3721,6 +3768,30 @@ class _ValidationInfo: insertion order, ensuring notes appear in a consistent sequence in reports and logs. """ + @classmethod + def from_agg_validator( + cls, + assertion_type: str, + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> _ValidationInfo: + _check_thresholds(thresholds=thresholds) + + return cls( + assertion_type=assertion_type, + column=_resolve_columns(columns), + values={"value": value, "tol": tol}, + thresholds=_normalize_thresholds_creation(thresholds), + brief=_transform_auto_brief(brief=brief), + actions=actions, + active=active, + ) + # Validation plan i: int | None = None i_o: int | None = None @@ -4971,6 +5042,411 @@ def set_tbl( def _repr_html_(self) -> str: return self.get_tabular_report()._repr_html_() # pragma: no cover + @_register_agg_validations + def col_avg_eq( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the avg of the values in a column is equal to some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sd_eq( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the standard deviation of the values in a column is equal to some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sd_gt( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the standard deviation of the values in a column is greater than some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sd_ge( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the standard deviation of the values in a column is greater than or equal to some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sd_lt( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the standard deviation of the values in a column is less than some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sd_le( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the standard deviation of the values in a column is less than or equal to some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_avg_ge( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the avg of the values in a column is greater or equal to some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_avg_gt( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the avg of the values in a column greater than some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_avg_le( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the avg of the values in a column is less than or equal to some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_avg_lt( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the avg of the values in a column is less than `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sum_eq( + self, + # TODO: Public type alias + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + # TODO: type alias this, especially the tuple/dict parts + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the sum of the values in a column is equal to some `value`. + + Args: + columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals): _description_ + value (float | Column): _description_ + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sum_gt( + self, + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the values in a column sum to a value greater than some `value`. + + Args: + columns (_PBUnresolvedColumn): _description_ + value (float | Column): _description_ + tol (Tolerance, optional): _description_. Defaults to 0. + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + active (bool, optional): _description_. Defaults to True. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sum_ge( + self, + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the values in a column sum to a value greater or equal than some `value`. + + Args: + columns (_PBUnresolvedColumn): _description_ + value (float | Column): _description_ + tol (Tolerance, optional): _description_. Defaults to 0. + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + active (bool, optional): _description_. Defaults to True. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sum_lt( + self, + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the values in a column sum to a value less than some `value`. + + Args: + columns (_PBUnresolvedColumn): _description_ + value (float | Column): _description_ + tol (Tolerance, optional): _description_. Defaults to 0. + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + active (bool, optional): _description_. Defaults to True. + + Returns: + Validate: _description_ + """ + pass + + @_register_agg_validations + def col_sum_le( + self, + columns: _PBUnresolvedColumn, + value: float | Column, + tol: Tolerance = 0, + thresholds: float | bool | tuple | dict | Thresholds | None = None, + brief: str | bool = False, + actions: Actions | None = None, + active: bool = True, + ) -> Validate: + """Assert the values in a column sum to a value less than or equal to some `value`. + + Args: + columns (_PBUnresolvedColumn): _description_ + value (float | Column): _description_ + tol (Tolerance, optional): _description_. Defaults to 0. + thresholds (float | bool | tuple | dict | Thresholds | None, optional): _description_. Defaults to None. + brief (str | bool, optional): _description_. Defaults to False. + actions (Actions | None, optional): _description_. Defaults to None. + active (bool, optional): _description_. Defaults to True. + + Returns: + Validate: _description_ + """ + pass + def col_vals_gt( self, columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals, @@ -5212,7 +5688,6 @@ def col_vals_gt( - Row 1: `c` is `1` and `b` is `2`. - Row 3: `c` is `2` and `b` is `2`. """ - assertion_type = _get_fn_name() _check_column(column=columns) @@ -5232,14 +5707,7 @@ def col_vals_gt( self.thresholds if thresholds is None else _normalize_thresholds_creation(thresholds) ) - # If `columns` is a ColumnSelector or Narwhals selector, call `col()` on it to later - # resolve the columns - if isinstance(columns, (ColumnSelector, nw.selectors.Selector)): - columns = col(columns) - - # If `columns` is Column value or a string, place it in a list for iteration - if isinstance(columns, (Column, str)): - columns = [columns] + columns = _resolve_columns(columns) # Determine brief to use (global or local) and transform any shorthands of `brief=` brief = self.brief if brief is None else _transform_auto_brief(brief=brief) @@ -12279,7 +12747,7 @@ def interrogate( segment = validation.segments # Get compatible data types for this assertion type - assertion_method = ASSERTION_TYPE_METHOD_MAP[assertion_type] + assertion_method = ASSERTION_TYPE_METHOD_MAP.get(assertion_type, assertion_type) compatible_dtypes = COMPATIBLE_DTYPES.get(assertion_method, []) # Process the `brief` text for the validation step by including template variables to @@ -12760,6 +13228,26 @@ def interrogate( tbl_type=tbl_type, ) + elif is_valid_agg(assertion_type): + agg, comp = resolve_agg_registries(assertion_type) + + # Produce a 1-column Narwhals DataFrame + # TODO: Should be able to take lazy too + vec: nw.DataFrame = nw.from_native(data_tbl_step).select(column) + real = agg(vec) + + target = value["value"] + tol = value["tol"] + lower_bound, upper_bound = _derive_bounds(target, tol) + + result_bool = comp(real, target - lower_bound, target + upper_bound) + + validation.all_passed = result_bool + validation.n = 1 + validation.n_passed = int(result_bool) + validation.n_failed = 1 - result_bool + + results_tbl = None else: raise ValueError( f"Unknown assertion type: {assertion_type}" diff --git a/pyproject.toml b/pyproject.toml index 1da906b57..4d48a489d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ docs = [ [dependency-groups] dev = [ "chatlas>=0.6.1", - "duckdb>=1.2.0,<1.3.3", # Pin to stable versions avoiding 1.4.0+ RecordBatchReader issues + "duckdb>=1.2.0,<1.3.3", # Pin to stable versions avoiding 1.4.0+ RecordBatchReader issues "griffe==0.38.1", "hypothesis>=6.129.2", "ibis-framework[duckdb,mysql,postgres,sqlite]>=9.5.0", diff --git a/tests/test_agg.py b/tests/test_agg.py new file mode 100644 index 000000000..514c0453b --- /dev/null +++ b/tests/test_agg.py @@ -0,0 +1,71 @@ +import pytest + +from pointblank import Validate +import polars as pl + + +@pytest.fixture +def simple_pl() -> pl.DataFrame: + return pl.DataFrame( + { + "a": [1, 1, 1, None], + "b": [2, 2, 2, None], + "c": [3, 3, 3, None], + } + ) + + +@pytest.mark.parametrize( + "tol", + [ + (0, 0), + (1, 1), + (100, 100), + 0, + ], +) +def test_sums_old(tol, simple_pl) -> None: + v = Validate(simple_pl).col_sum_eq("a", 3, tol=tol).interrogate() + + v.assert_below_threshold() + + v.get_tabular_report() + + +# TODO: Expand expression types +# TODO: Expand table types +@pytest.mark.parametrize( + ("method", "vals"), + [ + # Sum -> 3, 6, 9 + ("col_sum_eq", (3, 6, 9)), + ("col_sum_gt", (2, 5, 8)), + ("col_sum_ge", (3, 6, 9)), + ("col_sum_lt", (4, 7, 10)), + ("col_sum_le", (3, 6, 9)), + # Average -> 1, 2, 3 + ("col_avg_eq", (1, 2, 3)), + ("col_avg_gt", (0, 1, 2)), + ("col_avg_ge", (1, 2, 3)), + ("col_avg_lt", (2, 3, 4)), + ("col_avg_le", (1, 2, 3)), + # Standard Deviation -> 0, 0, 0 + ("col_sd_eq", (0, 0, 0)), + ("col_sd_gt", (-1, -1, -1)), + ("col_sd_ge", (0, 0, 0)), + ("col_sd_lt", (1, 1, 1)), + ("col_sd_le", (0, 0, 0)), + ], +) +def test_aggs(simple_pl: pl.DataFrame, method: str, vals: tuple[int, int, int]): + v = Validate(simple_pl) + for col, val in zip(["a", "b", "c"], vals): + getattr(v, method)(col, val) + v = v.interrogate() + + v.assert_below_threshold() + v.get_tabular_report() + + +if __name__ == "__main__": + pytest.main([__file__])