Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from pyiceberg.expressions import (
AlwaysFalse,
And,
BooleanExpression,
EqualTo,
In,
Expand All @@ -33,19 +34,40 @@
def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

if unique_keys.num_rows == 0:
return AlwaysFalse()

if len(join_cols) == 1:
return In(join_cols[0], unique_keys[0].to_pylist())
else:
filters = [
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
]

if len(filters) == 0:
return AlwaysFalse()
elif len(filters) == 1:
return filters[0]
else:
return Or(*filters)
return In(join_cols[0], unique_keys.column(join_cols[0]).to_pylist())

# Fold the column that leaves the fewest distinct "prefix" combinations into
# an In(); this minimises the disjunct count regardless of column order.
in_col = min(
join_cols,
key=lambda cand: unique_keys.select([c for c in join_cols if c != cand])
.group_by([c for c in join_cols if c != cand])
.aggregate([])
.num_rows,
)
prefix_cols = [c for c in join_cols if c != in_col]

# The group keys come first (in prefix_cols order) followed by the list aggregate.
# Rename the aggregate to a sentinel so it cannot collide with a join column that
# happens to be named f"{in_col}_list".
in_values_col = "__in_values"
while in_values_col in prefix_cols:
in_values_col += "_"
grouped = unique_keys.group_by(prefix_cols).aggregate([(in_col, "list")]).rename_columns([*prefix_cols, in_values_col])

disjuncts: list[BooleanExpression] = []
for row in grouped.to_pylist():
eqs = [EqualTo(c, row[c]) for c in prefix_cols]
prefix_pred = functools.reduce(operator.and_, eqs) if len(eqs) > 1 else eqs[0]
disjuncts.append(And(prefix_pred, In(in_col, row[in_values_col])))

if len(disjuncts) == 1:
return disjuncts[0]
return Or(*disjuncts)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
Expand Down
180 changes: 176 additions & 4 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import itertools
from pathlib import PosixPath

import pyarrow as pa
Expand All @@ -23,13 +24,15 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions import AlwaysFalse, AlwaysTrue, And, EqualTo, In, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.expressions.visitors import expression_evaluator
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.table import Table, UpsertResult
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.upsert_util import create_match_filter
from pyiceberg.typedef import Record
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
from tests.catalog.test_base import InMemoryCatalog

Expand Down Expand Up @@ -437,10 +440,179 @@ def test_create_match_filter_single_condition() -> None:
schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())])
table = pa.Table.from_pylist(data, schema=schema)
expr = create_match_filter(table, ["order_id", "order_line_id"])
assert expr == And(
EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)),
EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)),
# Be insensitive to left/right operands
op1 = EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101))
op2 = EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1))
assert expr == And(op1, op2) or expr == And(op2, op1)


def _assert_match_filter_selects(data: list[dict[str, int]], join_cols: list[str], schema: Schema) -> None:
"""Assert the filter from ``create_match_filter`` matches exactly the unique source keys.

Rather than asserting a specific expression tree (which is implementation-specific),
this binds the filter and evaluates it against the full cross-product of the values
observed per column. The filter must accept exactly the unique keys present in
``data`` and reject every other combination, so any over- or under-matching
(e.g. a cross-product regression) is caught. This holds for any correct
implementation of ``create_match_filter``.
"""
arrow_schema = schema_to_pyarrow(schema)
table = pa.Table.from_pylist(data, schema=arrow_schema)
expr = create_match_filter(table, join_cols)

field_names = [field.name for field in schema.fields]
expected_keys = {tuple(row[name] for name in field_names) for row in data}
domains = [sorted({row[name] for row in data}) for name in field_names]

evaluate = expression_evaluator(schema, expr, case_sensitive=True)
for candidate in itertools.product(*domains):
key = dict(zip(field_names, candidate, strict=True))
should_match = candidate in expected_keys
verb = "rejected matching" if should_match else "matched non-matching"
assert evaluate(Record(*candidate)) is should_match, f"Filter {expr} {verb} key {key}"


def test_create_match_filter_single_prefix_group() -> None:
"""
Test create_match_filter with multiple key columns whose rows all share a single prefix combination.

The filter must match the (one order_id, many order_line_id) keys and nothing else.
"""
schema = Schema(
NestedField(1, "order_id", IntegerType(), required=True),
NestedField(2, "order_line_id", IntegerType(), required=True),
)
data = [
{"order_id": 101, "order_line_id": 1},
{"order_id": 101, "order_line_id": 2},
{"order_id": 101, "order_line_id": 3},
{"order_id": 101, "order_line_id": 3}, # duplicate
]
_assert_match_filter_selects(data, ["order_id", "order_line_id"], schema)


def test_create_match_filter_multiple_prefix_groups() -> None:
"""
Test create_match_filter with multiple key columns that yield several distinct prefix combinations.

The filter must match exactly the listed composite keys and must NOT match cross-product
combinations that never appear together (e.g. order_id 101 with order_line_id 2).
"""
schema = Schema(
NestedField(1, "order_id", IntegerType(), required=True),
NestedField(2, "order_line_id", IntegerType(), required=True),
)
data = [
{"order_id": 101, "order_line_id": 1},
{"order_id": 102, "order_line_id": 1},
{"order_id": 103, "order_line_id": 1},
{"order_id": 201, "order_line_id": 2},
{"order_id": 202, "order_line_id": 2},
]
_assert_match_filter_selects(data, ["order_id", "order_line_id"], schema)


def test_create_match_filter_single_column() -> None:
"""A single join column collapses to a single In() over the unique values."""
schema = pa.schema([pa.field("order_id", pa.int32())])
table = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 2}, {"order_id": 2}], schema=schema)
assert create_match_filter(table, ["order_id"]) == In("order_id", [1, 2])


def test_create_match_filter_single_column_single_value() -> None:
"""A single unique value collapses the In() down to an EqualTo()."""
schema = pa.schema([pa.field("order_id", pa.int32())])
table = pa.Table.from_pylist([{"order_id": 1}, {"order_id": 1}], schema=schema)
assert create_match_filter(table, ["order_id"]) == EqualTo("order_id", 1)


def test_create_match_filter_empty_input() -> None:
"""An empty source matches nothing (AlwaysFalse), for both single and composite keys."""
schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32())])
empty = pa.Table.from_pylist([], schema=schema)
assert create_match_filter(empty, ["order_id"]) == AlwaysFalse()
assert create_match_filter(empty, ["order_id", "order_line_id"]) == AlwaysFalse()


def test_create_match_filter_three_columns() -> None:
"""
Test create_match_filter with three key columns.

Exercises the multi-column prefix branch where the prefix predicate is an And of two
EqualTo() conjuncts combined with an In() over the folded column.
"""
schema = Schema(
NestedField(1, "a", IntegerType(), required=True),
NestedField(2, "b", IntegerType(), required=True),
NestedField(3, "c", IntegerType(), required=True),
)
data = [
{"a": 1, "b": 1, "c": 1},
{"a": 1, "b": 1, "c": 2},
{"a": 1, "b": 1, "c": 3},
{"a": 2, "b": 9, "c": 5},
{"a": 2, "b": 9, "c": 6},
]
_assert_match_filter_selects(data, ["a", "b", "c"], schema)


def test_create_match_filter_column_named_like_aggregate() -> None:
"""
Regression test for #3509 review feedback.

A join column named ``<in_col>_list`` must not collide with the internal list-aggregation
column used to fold values into an In(). Before the fix this raised a TypeError.
"""
schema = Schema(
NestedField(1, "a", IntegerType(), required=True),
NestedField(2, "a_list", IntegerType(), required=True),
)
data = [
{"a": 1, "a_list": 7},
{"a": 2, "a_list": 7},
{"a": 3, "a_list": 8},
]
_assert_match_filter_selects(data, ["a", "a_list"], schema)


def test_upsert_large_composite_key_does_not_overflow(catalog: Catalog) -> None:
"""
Regression test for #3508: a large multi-column upsert must not overflow PyArrow's
expression canonicalizer when at least one key column is low-cardinality (see #3509).
"""
identifier = "default.test_upsert_large_composite_key"
_drop_table(catalog, identifier)

n = 20_000
schema = pa.schema(
[
pa.field("order_id", pa.int64(), nullable=False),
pa.field("region", pa.string(), nullable=False),
pa.field("amount", pa.int64(), nullable=False),
]
)

def make(order_ids: range, amount: int) -> pa.Table:
# region is intentionally low-cardinality (4 values) so the fix folds order_id into an In().
return pa.Table.from_pylist(
[{"order_id": oid, "region": "ABCD"[oid % 4], "amount": amount} for oid in order_ids],
schema=schema,
)

tbl = catalog.create_table(identifier, schema)
tbl.append(make(range(1, n + 1), amount=1))

# Update the first half (amount changes) and insert a tenth of brand-new keys.
source = pa.concat_tables(
[
make(range(1, n // 2 + 1), amount=2),
make(range(n + 1, n + n // 10 + 1), amount=2),
]
)

res = tbl.upsert(source, join_cols=["order_id", "region"])
assert res.rows_updated == n // 2
assert res.rows_inserted == n // 10


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
Expand Down