diff --git a/pointblank/_agg.py b/pointblank/_agg.py
new file mode 100644
index 000000000..5d4cf1655
--- /dev/null
+++ b/pointblank/_agg.py
@@ -0,0 +1,112 @@
+from collections.abc import Callable
+from typing import Any
+
+import narwhals as nw
+
+# TODO: Should take any frame type
+Aggregator = Callable[[nw.DataFrame], Any]
+Comparator = Callable[[Any, Any], bool]
+
+AGGREGATOR_REGISTRY: dict[str, Aggregator] = {}
+
+COMPARATOR_REGISTRY: dict[str, Comparator] = {}
+
+
+def register(fn):
+ name: str = fn.__name__
+ if name.startswith("comp_"):
+ COMPARATOR_REGISTRY[name.removeprefix("comp_")] = fn
+ elif name.startswith("agg_"):
+ AGGREGATOR_REGISTRY[name.removeprefix("agg_")] = fn
+ else:
+ raise NotImplementedError # pragma: no cover
+ return fn
+
+
+## Aggregator Functions
+@register
+def agg_sum(column: nw.DataFrame) -> float:
+ return column.select(nw.all().sum()).item()
+
+
+@register
+def agg_avg(column: nw.DataFrame) -> float:
+ return column.select(nw.all().mean()).item()
+
+
+@register
+def agg_sd(column: nw.DataFrame) -> float:
+ return column.select(nw.all().std()).item()
+
+
+## Comparator functions
+@register
+def comp_eq(real: float, lower: float, upper: float) -> bool:
+ if lower == upper:
+ return bool(real == lower)
+ return _generic_between(real, lower, upper)
+
+
+@register
+def comp_gt(real: float, lower: float, upper: float) -> bool:
+ if lower == upper:
+ return bool(real > lower)
+ return bool(real > lower)
+
+
+@register
+def comp_ge(real: Any, lower: float, upper: float) -> bool:
+ if lower == upper:
+ return bool(real >= lower)
+ return bool(real >= lower)
+
+
+@register
+def comp_lt(real: float, lower: float, upper: float) -> bool:
+ if lower == upper:
+ return bool(real < lower)
+ return bool(real < upper)
+
+
+@register
+def comp_le(real: float, lower: float, upper: float) -> bool:
+ if lower == upper:
+ return bool(real <= lower)
+ return bool(real <= upper)
+
+
+def _generic_between(real: Any, lower: Any, upper: Any) -> bool:
+ """Call if comparator needs to check between two values."""
+ return bool(lower <= real <= upper)
+
+
+def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]:
+ """Resolve the assertion name to a valid aggregator
+
+ Args:
+ name (str): The name of the assertion.
+
+ Returns:
+ tuple[Aggregator, Comparator]: _description_
+ """
+ name = name.removeprefix("col_")
+ agg_name, comp_name = name.split("_")[-2:]
+
+ aggregator = AGGREGATOR_REGISTRY.get(agg_name)
+ comparator = COMPARATOR_REGISTRY.get(comp_name)
+
+ if aggregator is None:
+ raise ValueError(f"Aggregator '{agg_name}' not found in registry.")
+
+ if comparator is None:
+ raise ValueError(f"Comparator '{comp_name}' not found in registry.")
+
+ return aggregator, comparator
+
+
+def is_valid_agg(name: str) -> bool:
+ try:
+ resolve_agg_registries(name)
+ return True
+ except ValueError:
+ return False
diff --git a/pointblank/_constants.py b/pointblank/_constants.py
index 39d658460..4678f7f50 100644
--- a/pointblank/_constants.py
+++ b/pointblank/_constants.py
@@ -51,6 +51,7 @@
"tbl_match": "tbl_match",
"conjointly": "conjointly",
"specially": "specially",
+ "col_sum_eq": "sum_eq",
}
COMPARISON_OPERATORS = {
@@ -220,6 +221,48 @@
CROSS_MARK_SPAN = "✗"
SVG_ICONS_FOR_ASSERTION_TYPES = {
+ ## EQ Icons
+ "col_vals_eq": """
+""",
+ "col_sum_eq": """
+""",
+ "col_sd_eq": """
+""",
+ "col_avg_eq": """
+""",
+ ## GT Icons
"col_vals_gt": """
""",
+ "col_sum_gt": """
+""",
+ "col_sd_gt": """
+""",
+ "col_avg_gt": """
+""",
+ ## LT Icons
"col_vals_lt": """
""",
- "col_vals_eq": """
+ "col_sum_lt": """
""",
+ "col_sd_lt": """
+""",
+ "col_avg_lt": """
+""",
+ ## NE Icons
"col_vals_ne": """
""",
+ "col_sum_ne": """
+""",
+ "col_sd_ne": """
+""",
+ "col_avg_ne": """
+""",
+ ## GE Icons
"col_vals_ge": """
""",
+ "col_sum_ge": """
+""",
+ "col_sd_ge": """
+""",
+ "col_avg_ge": """
+""",
+ ## LE Icons
"col_vals_le": """
""",
+ "col_sum_le": """
+""",
+ "col_sd_le": """
+""",
+ "col_avg_le": """
+""",
"col_vals_between": """