Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,20 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
f"DataFrames, and cannot be used as column names"
) from None

# Step 1: Prepare source index with join keys and a marker index
# Cast to target table schema, so we can do the join
# See: https://github.com/apache/arrow/issues/37542
# Step 1: Prepare source index with join keys and a marker index.
# Cast only join columns to target join-column schema so schema evolution
# (for example, newly added non-key columns) doesn't break the join setup.
source_index = (
source_table.cast(target_table.schema)
.select(join_cols_set)
source_table.select(join_cols)
.cast(pa.schema([target_table.schema.field(col) for col in join_cols]))
.append_column(SOURCE_INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
)

# Step 2: Prepare target index with join keys and a marker
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
target_index = target_table.select(join_cols).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))

# Step 3: Perform an inner join to find which rows from source exist in target
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
matching_indices = source_index.join(target_index, keys=join_cols, join_type="inner")

# Step 4: Compare all rows using Python
to_update_indices = []
Expand All @@ -112,7 +112,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols

for key in non_key_cols:
source_val = source_row.column(key)[0].as_py()
target_val = target_row.column(key)[0].as_py()
target_val = target_row.column(key)[0].as_py() if key in target_table.column_names else None
if source_val != target_val:
to_update_indices.append(source_idx)
break
Expand Down
51 changes: 51 additions & 0 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,57 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
)


def test_upsert_after_schema_add_column(catalog: Catalog) -> None:
identifier = "default.test_upsert_after_schema_add_column"
_drop_table(catalog, identifier)

schema = Schema(
NestedField(1, "id", IntegerType(), required=True),
NestedField(2, "name", StringType(), required=True),
identifier_field_ids=[1],
)

tbl = catalog.create_table(identifier, schema=schema)

initial = pa.Table.from_pylist(
[{"id": 1, "name": "Alice"}],
schema=pa.schema(
[
pa.field("id", pa.int32(), nullable=False),
pa.field("name", pa.string(), nullable=False),
]
),
)
tbl.append(initial)

with tbl.update_schema() as update_schema:
update_schema.add_column("country", StringType())
tbl = tbl.refresh()

source = pa.Table.from_pylist(
[
{"id": 1, "name": "Alice", "country": "NL"},
{"id": 2, "name": "Bob", "country": "US"},
],
schema=pa.schema(
[
pa.field("id", pa.int32(), nullable=False),
pa.field("name", pa.string(), nullable=False),
pa.field("country", pa.string(), nullable=True),
]
),
)

upd = tbl.upsert(source, ["id"])

assert upd.rows_updated == 1
assert upd.rows_inserted == 1
assert sorted(tbl.scan().to_arrow().to_pylist(), key=lambda row: row["id"]) == [
{"id": 1, "name": "Alice", "country": "NL"},
{"id": 2, "name": "Bob", "country": "US"},
]


def test_transaction(catalog: Catalog) -> None:
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is
rolled back."""
Expand Down