Skip to content
Open
75 changes: 57 additions & 18 deletions paimon-python/pypaimon/ray/data_evolution_merge_into.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pypaimon.ray.data_evolution_merge_join import (
build_matched_update_ds,
build_not_matched_insert_ds,
build_self_merge_update_ds,
distributed_update_apply,
distributed_write_collect_msgs,
)
Expand Down Expand Up @@ -53,6 +54,7 @@ class _PrepareCtx:
update_pa_schema: pa.Schema
full_pa_schema: pa.Schema
catalog_options: Dict[str, str]
is_self_merge: bool = False


def merge_into(
Expand Down Expand Up @@ -178,33 +180,45 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
_NormalizedClause(spec=spec, condition=c.condition)
)

source_snapshot_id = None
if isinstance(source, str):
source_snapshot = (
catalog.get_table(source)
.snapshot_manager()
.get_latest_snapshot()
is_self_merge = _is_self_merge(target, source, target_on_cols, source_on_cols)
if is_self_merge and not_matched_specs:
raise ValueError(
"Self-merge (source == target with ON _ROW_ID) does not "
"support WHEN NOT MATCHED clauses."
)
if source_snapshot is not None:
source_snapshot_id = source_snapshot.id

source_ds = _normalize_source(
source, catalog_options, source_snapshot_id=source_snapshot_id,
)
_validate_source_on_cols(source_ds, source_on_cols)
if is_self_merge:
source_ds = None
source_col_names = set(full_target_field_names) | set(source_on_cols)
else:
source_snapshot_id = None
if isinstance(source, str):
source_snapshot = (
catalog.get_table(source)
.snapshot_manager()
.get_latest_snapshot()
)
if source_snapshot is not None:
source_snapshot_id = source_snapshot.id
source_ds = _normalize_source(
source, catalog_options, source_snapshot_id=source_snapshot_id,
)
_validate_source_on_cols(source_ds, source_on_cols)
source_col_names = set(_source_schema_or_raise(source_ds).names)
_validate_source_has_target_cols(
source_ds, matched_specs + not_matched_specs,
source_col_names, matched_specs + not_matched_specs,
)

if has_condition:
from pypaimon.ray.merge_condition import extract_columns
source_names = set(_source_schema_or_raise(source_ds).names)
target_names = set(full_target_field_names)
if is_self_merge:
target_names |= set(target_on_cols)
for c in list(when_matched) + list(when_not_matched):
if c.condition is not None:
for ref in extract_columns(c.condition):
prefix, col = ref.split(".", 1)
if prefix == "s" and col not in source_names:
if prefix == "s" and col not in source_col_names:
raise ValueError(
f"condition references unknown source "
f"column '{col}'"
Expand Down Expand Up @@ -233,10 +247,20 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
update_pa_schema=update_pa_schema,
full_pa_schema=full_pa_schema,
catalog_options=catalog_options,
is_self_merge=is_self_merge,
)
return table, source_ds, matched_specs, not_matched_specs, ctx


def _is_self_merge(target, source, target_on_cols, source_on_cols) -> bool:
from pypaimon.table.special_fields import SpecialFields
row_id_name = SpecialFields.ROW_ID.name
return (isinstance(source, str)
and source == target
and target_on_cols == [row_id_name]
and source_on_cols == [row_id_name])


def _build_datasets(
target, source_ds, matched_specs, not_matched_specs,
ctx: "_PrepareCtx", base_snapshot, num_partitions, ray_remote_args,
Expand All @@ -250,6 +274,22 @@ def _build_datasets(
insert_ds = None
update_cols_union: List[str] = []

if ctx.is_self_merge:
if matched_specs and base_snapshot is not None:
update_cols_union = _union_update_cols(matched_specs)
update_ds = build_self_merge_update_ds(
target_identifier=target,
clauses=matched_specs,
target_field_names=ctx.full_target_field_names,
target_pa_schema=ctx.update_pa_schema,
update_cols=update_cols_union,
catalog_options=ctx.catalog_options,
resolve_target_projection=_resolve_target_projection,
snapshot_id=base_snapshot_id,
ray_remote_args=ray_remote_args,
)
return update_ds, insert_ds, update_cols_union

# Mirror Spark: matched/not-matched run as two independent joins
# (inner / left_anti). One unified left_outer join would force
# joined.materialize() to feed both branches, which can OOM on large merges.
Expand Down Expand Up @@ -561,16 +601,15 @@ def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None:


def _validate_source_has_target_cols(
source_ds,
source_col_names: set,
specs: List[_NormalizedClause],
) -> None:
names = set(_source_schema_or_raise(source_ds).names)
needed = set()
for clause in specs:
for val in clause.spec.values():
if isinstance(val, SourceColumnRef):
needed.add(val.column)
missing = sorted(needed - names)
missing = sorted(needed - source_col_names)
if missing:
raise ValueError(
f"source is missing columns {missing} referenced by SET spec"
Expand Down
192 changes: 140 additions & 52 deletions paimon-python/pypaimon/ray/data_evolution_merge_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pyarrow as pa

from pypaimon.ray.data_evolution_merge_transform import (
SourceColumnRef,
_NormalizedClause,
build_update_schema,
vectorized_insert_transform,
Expand All @@ -40,6 +41,137 @@ def _map_kwargs(
return kwargs


def _build_matched_transform(
clauses: List[_NormalizedClause],
on_map: Dict[str, str],
on_pairs: List[Tuple[str, str]],
update_cols: List[str],
row_id_name: str,
update_schema: pa.Schema,
):
prepared_clauses = []
for clause in clauses:
rewritten = None
if clause.condition is not None:
from pypaimon.ray.merge_condition import (
remap_source_on_keys, rewrite_condition,
)
rewritten = remap_source_on_keys(
rewrite_condition(clause.condition), on_map,
)
prepared_clauses.append((clause.spec, rewritten))

_filter_batch = None
if any(r is not None for _, r in prepared_clauses):
from pypaimon.ray.merge_condition import filter_batch as _filter_batch

def _transform(batch: pa.Table) -> pa.Table:
remaining = batch
parts = []
for spec, rewritten in prepared_clauses:
if remaining.num_rows == 0:
break
if rewritten is not None:
matched = _filter_batch(
remaining, rewritten, _pre_rewritten=True,
)
else:
matched = remaining
if matched.num_rows == 0:
continue
parts.append(vectorized_matched_transform(
matched, spec, on_pairs,
update_cols, row_id_name,
update_schema,
))
if rewritten is not None and matched.num_rows < remaining.num_rows:
not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
remaining = _filter_batch(
remaining, not_cond, _pre_rewritten=True,
)
else:
remaining = remaining.slice(0, 0)
if not parts:
return update_schema.empty_table()
return pa.concat_tables(parts)

return _transform


def build_self_merge_update_ds(
*,
target_identifier: str,
clauses: List[_NormalizedClause],
target_field_names: Sequence[str],
target_pa_schema: pa.Schema,
update_cols: Sequence[str],
catalog_options: Dict[str, str],
resolve_target_projection,
snapshot_id: Optional[int] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
) -> Tuple:
from pypaimon.ray.ray_paimon import read_paimon
from pypaimon.table.special_fields import SpecialFields

row_id_name = SpecialFields.ROW_ID.name
needed_cols = set(resolve_target_projection(
clauses, [row_id_name], update_cols, target_field_names,
))
for clause in clauses:
for value in clause.spec.values():
if isinstance(value, SourceColumnRef):
needed_cols.add(value.column)
target_set = set(target_field_names)
for clause in clauses:
if clause.condition is not None:
from pypaimon.ray.merge_condition import extract_columns
for ref in extract_columns(clause.condition):
prefix, col = ref.split(".", 1)
if prefix == "s" and col in target_set:
needed_cols.add(col)
projection = [row_id_name] + [
c for c in target_field_names if c in needed_cols
]

target_ds = read_paimon(
target_identifier, catalog_options,
projection=projection, snapshot_id=snapshot_id,
)
update_schema = build_update_schema(target_pa_schema, update_cols, row_id_name)

orig_names = target_ds.schema().names
target_renamed = target_ds.rename_columns(
{c: f"t.{c}" for c in orig_names}
)

def _add_source_aliases(batch: pa.Table) -> pa.Table:
columns = list(batch.columns)
names = list(batch.schema.names)
for orig in orig_names:
if orig == row_id_name:
continue
t_col_name = f"t.{orig}"
if t_col_name in names:
idx = names.index(t_col_name)
columns.append(columns[idx])
names.append(f"s.{orig}")
return pa.table(columns, names=names)

aliased = target_renamed.map_batches(
_add_source_aliases, **_map_kwargs(ray_remote_args),
)

_transform = _build_matched_transform(
clauses,
on_map={row_id_name: row_id_name},
on_pairs=[(row_id_name, row_id_name)],
update_cols=list(update_cols),
row_id_name=row_id_name,
update_schema=update_schema,
)
return aliased.map_batches(_transform, **_map_kwargs(ray_remote_args))


def build_matched_update_ds(
*,
target_identifier: str,
Expand Down Expand Up @@ -87,58 +219,14 @@ def build_matched_update_ds(
right_on=tuple(f"s.{c}" for c in source_on),
)

captured_update_cols = list(update_cols)
captured_row_id_name = row_id_name
captured_on_pairs = list(zip(source_on, target_on))
captured_schema = update_schema

on_map = dict(zip(source_on, target_on))
prepared_clauses = []
for clause in clauses:
rewritten = None
if clause.condition is not None:
from pypaimon.ray.merge_condition import (
remap_source_on_keys, rewrite_condition,
)
rewritten = remap_source_on_keys(
rewrite_condition(clause.condition), on_map,
)
prepared_clauses.append((clause.spec, rewritten))

_filter_batch = None
if any(r is not None for _, r in prepared_clauses):
from pypaimon.ray.merge_condition import filter_batch as _filter_batch

def _transform(batch: pa.Table) -> pa.Table:
remaining = batch
parts = []
for spec, rewritten in prepared_clauses:
if remaining.num_rows == 0:
break
if rewritten is not None:
matched = _filter_batch(
remaining, rewritten, _pre_rewritten=True,
)
else:
matched = remaining
if matched.num_rows == 0:
continue
parts.append(vectorized_matched_transform(
matched, spec, captured_on_pairs,
captured_update_cols, captured_row_id_name,
captured_schema,
))
if rewritten is not None and matched.num_rows < remaining.num_rows:
not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
remaining = _filter_batch(
remaining, not_cond, _pre_rewritten=True,
)
else:
remaining = remaining.slice(0, 0)
if not parts:
return captured_schema.empty_table()
return pa.concat_tables(parts)

_transform = _build_matched_transform(
clauses,
on_map=dict(zip(source_on, target_on)),
on_pairs=list(zip(source_on, target_on)),
update_cols=list(update_cols),
row_id_name=row_id_name,
update_schema=update_schema,
)
return joined.map_batches(_transform, **_map_kwargs(ray_remote_args))


Expand Down
Loading
Loading